"""
agents/critic.py - Critic Agent (Devil's Advocate)
Harness Layers:
Layer 1 - ENABLE_CRITIC=false -> SKIPPED
Layer 1b - Fallback when the LLM output is unparseable
Layer 2 - Schema validation
Layer 2' - Deep scorecard repair (fixes shallow-merge defect)
Layer 3 - weighted_score type safety + verdict enum (fixes TypeError defect)
"""
import json, logging, os, re, time
from datetime import datetime, timezone
from typing import Any
from crewai import Agent, Task
from config import ENABLE_CRITIC, MAX_DEBATE_ROUNDS, get_llm
from tools.kev_tool import check_cisa_kev
from tools.exploit_tool import search_exploits
from tools.memory_tool import read_memory
logger = logging.getLogger("ThreatHunter.Critic")
CONSTITUTION = """
=== ThreatHunter Constitution ===
1. All CVE IDs must come from Tool-returned data. Fabrication is prohibited.
2. You must use the provided Tools for queries. Skip is not allowed.
3. Output must conform to the specified JSON schema.
4. Uncertain reasoning must be tagged with confidence: HIGH / MEDIUM / NEEDS_VERIFICATION.
5. Each judgment must include a reasoning field.
6. Reports use English; technical terms are not translated.
7. Do not call the same Tool twice for the same data.
"""
def _load_skill(skill_filename: str = "debate_sop.md") -> str:
skill_path = os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
"skills",
skill_filename,
)
try:
with open(skill_path, "r", encoding="utf-8") as skill_file:
return skill_file.read()
except FileNotFoundError:
return "## Skill: Debate SOP\nChallenge Analyst assumptions. Use tools to verify."
# v3.7: Path-Aware Skill Map (keys correspond to the recorder.stage_enter stages in main.py)
SKILL_MAP: dict[str, str] = {
"pkg": "debate_sop.md", # Path A: package debate
"code": "code_debate_sop.md", # Path B-code: source code debate
"injection": "ai_debate_sop.md", # Path B-inject: AI security debate
"config": "config_debate_sop.md", # Path C: config debate
}
VALID_VERDICTS = {"MAINTAIN", "DOWNGRADE", "SKIPPED"}
SCORECARD_FIELDS = ["evidence", "chain_completeness", "critique_quality", "defense_quality", "calibration"]
WEIGHTS = {"evidence": 0.30, "chain_completeness": 0.25, "critique_quality": 0.20, "defense_quality": 0.15, "calibration": 0.10}
def create_critic_agent(
excluded_models: list[str] | None = None,
input_type: str = "pkg",
) -> Agent:
"""Build the Critic Agent (Devil's Advocate).
Args:
excluded_models: Models to skip (429 rate-limited models)
"""
skill_filename = SKILL_MAP.get(input_type, "debate_sop.md")
skill_content = _load_skill(skill_filename)
return Agent(
role="Security Debate Advisor (Critic / Devil's Advocate)",
goal=(
"Challenge Analyst Agent results via adversarial debate. "
"Validate every prerequisite with tools, detect overconfidence, "
"output a 5-dimensional scorecard and verdict (MAINTAIN / DOWNGRADE / SKIPPED)."
),
backstory=f"""You are a rigorous Red Team Analyst.
{CONSTITUTION}
## Debate SOP (from skills/{skill_filename})
{skill_content}
## Output Specification (Critic Data Contract)
Output ONLY the following JSON, no text outside it:
```json
{{
"debate_rounds": 2,
"challenges": ["Challenge 1: description (English)"],
"scorecard": {{
"evidence": 0.85, "chain_completeness": 0.80,
"critique_quality": 0.75, "defense_quality": 0.70, "calibration": 0.90
}},
"weighted_score": 80.5,
"verdict": "MAINTAIN",
"reasoning": "One sentence verdict rationale (English)",
"generated_at": "ISO 8601 timestamp"
}}
```
## Verdict Rules
- weighted_score >= 70 -> verdict: "MAINTAIN"
- 50 <= weighted_score < 70 -> verdict: "MAINTAIN" (with challenge notes)
- weighted_score < 50 -> verdict: "DOWNGRADE"
## Prohibited Actions
- Do NOT downgrade a CVE with in_cisa_kev=true
- Do NOT conclude without calling at least one tool
""",
tools=[check_cisa_kev, search_exploits, read_memory],
        llm=get_llm(exclude_models=excluded_models),  # lazy init: only resolved when the Agent is built
verbose=True,
max_iter=8,
allow_delegation=False,
)
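# Illustrative usage (the excluded model name is hypothetical; get_llm() picks the real one):
#   agent = create_critic_agent(excluded_models=["some-rate-limited-model"], input_type="code")
#   -> the backstory embeds skills/code_debate_sop.md, selected via SKILL_MAP.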
def create_critic_task(agent: Agent, analyst_output: str) -> Task:
"""Build Critic Task."""
return Task(
description=f"""
You are the Devil's Advocate. Analyst Agent result:
{analyst_output}
Steps (max {MAX_DEBATE_ROUNDS} rounds):
1. For chain_risk.is_chain=true: call check_cisa_kev + search_exploits
2. For confidence=HIGH with low tool coverage: detect overconfidence
3. Calculate 5D scorecard (evidence/chain_completeness/critique_quality/defense_quality/calibration)
4. Determine verdict: MAINTAIN or DOWNGRADE
5. Output complete JSON debate report
Note: if any CVE has in_cisa_kev=true, DOWNGRADE is prohibited.
""",
expected_output="Complete JSON debate report with debate_rounds, challenges, scorecard, weighted_score, verdict, reasoning.",
agent=agent,
)
def _extract_json_from_output(raw: str) -> dict[str, Any]:
"""Extract JSON from LLM output (tolerates Markdown wrapping)."""
if not raw or not isinstance(raw, str):
return {}
try:
return json.loads(raw)
except (json.JSONDecodeError, ValueError):
pass
match = re.search(r"```(?:json)?\s*([\s\S]+?)```", raw)
if match:
try:
return json.loads(match.group(1).strip())
except (json.JSONDecodeError, ValueError):
pass
match = re.search(r"\{[\s\S]+\}", raw)
if match:
try:
return json.loads(match.group(0))
except (json.JSONDecodeError, ValueError):
pass
return {}
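# Parsing is attempted in three tiers: the raw string as JSON, then a ```json fenced
# block, then the outermost {...} span. Illustrative example:
#   _extract_json_from_output('noise ```json\n{"verdict": "MAINTAIN"}\n``` noise')
#   -> {"verdict": "MAINTAIN"}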
def _build_skipped_output(reason: str = "ENABLE_CRITIC=false") -> dict[str, Any]:
"""Harness Layer 1: Build SKIPPED output."""
return {
"debate_rounds": 0, "challenges": [],
"scorecard": {field: 1.0 for field in SCORECARD_FIELDS},
"weighted_score": 100.0, "verdict": "SKIPPED",
"reasoning": reason,
"generated_at": datetime.now(timezone.utc).isoformat(),
"_harness_skipped": True,
}
def _compute_weighted_score(scorecard: dict[str, Any]) -> float:
"""
Calculate 5D weighted score (0-100). Type-safe: float() conversion per field.
FIX: Invalid types fall back to 0.5 instead of raising TypeError.
"""
total = 0.0
for field, weight in WEIGHTS.items():
raw = scorecard.get(field, 0.5)
try:
val = float(raw)
except (TypeError, ValueError):
logger.warning("Critic scorecard.%s = %r invalid type, using 0.5", field, raw)
val = 0.5
total += max(0.0, min(1.0, val)) * weight
return round(min(100.0, max(0.0, total * 100)), 2)
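# Worked example using the scorecard sample from the data contract above:
#   0.85*0.30 + 0.80*0.25 + 0.75*0.20 + 0.70*0.15 + 0.90*0.10 = 0.80 -> 80.0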
def _build_fallback_output(analyst_data: dict[str, Any]) -> dict[str, Any]:
"""Harness guarantee: minimum viable debate report when LLM output unparseable."""
analysis = analyst_data.get("analysis", analyst_data.get("vulnerabilities", []))
round_no = int(analyst_data.get("_debate_round", 1) or 1)
max_rounds = int(analyst_data.get("_debate_max_rounds", MAX_DEBATE_ROUNDS) or MAX_DEBATE_ROUNDS)
challenges = [
f"Challenge: {item.get('cve_id') or item.get('cwe_id') or 'Evidence item'} chain prerequisites not fully verified."
for item in analysis
if item.get("chain_risk", {}).get("is_chain") and item.get("chain_risk", {}).get("confidence") == "HIGH"
] or ["No specific challenges raised (fallback mode)."]
scorecard = {"evidence": 0.6, "chain_completeness": 0.5, "critique_quality": 0.6, "defense_quality": 0.7, "calibration": 0.7}
weighted_score = _compute_weighted_score(scorecard)
return {
"debate_rounds": round_no, "challenges": challenges, "scorecard": scorecard,
"weighted_score": weighted_score,
"verdict": "MAINTAIN" if weighted_score >= 50 else "DOWNGRADE",
"reasoning": "Fallback evaluation: LLM output unavailable. Conservative scoring applied.",
"generated_at": datetime.now(timezone.utc).isoformat(),
"_harness_fallback": True,
"_critic_round": round_no,
"_max_rounds": max_rounds,
}
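# With the fixed fallback scorecard above, _compute_weighted_score() yields
# 0.18 + 0.125 + 0.12 + 0.105 + 0.07 = 0.60 -> 60.0, so the fallback verdict is MAINTAIN.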
def _harness_validate_schema(output: dict[str, Any]) -> list[str]:
"""Harness Layer 2: Validate output format (data_contracts.md Critic section)."""
errors = []
for k in ["debate_rounds", "challenges", "scorecard", "weighted_score", "verdict"]:
if k not in output:
errors.append(f"Missing required field: {k}")
scorecard = output.get("scorecard", {})
if not isinstance(scorecard, dict):
errors.append("scorecard must be an object")
return errors
for field in SCORECARD_FIELDS:
if field not in scorecard:
errors.append(f"scorecard missing field: {field}")
else:
try:
val = float(scorecard[field])
if not (0.0 <= val <= 1.0):
errors.append(f"scorecard.{field} out of range ({val})")
except (TypeError, ValueError):
errors.append(f"scorecard.{field} invalid type: {scorecard[field]!r}")
return errors
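# Layer 2 only reports problems; it never mutates `output`. Repairs are applied by
# _harness_repair_scorecard() (Layer 2') and _harness_validate_verdict() (Layer 3) below.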
def _harness_repair_scorecard(output: dict[str, Any]) -> None:
"""
Harness Layer 2': Deep scorecard repair.
FIX for stress test defect: 'Shallow merge defect - 2,500 cases with missing
calibration field escaped because Layer 2 only does top-level key comparison.'
"""
scorecard = output.get("scorecard")
if not isinstance(scorecard, dict):
output["scorecard"] = {field: 0.6 for field in SCORECARD_FIELDS}
logger.warning("Harness Layer 2': scorecard entirely missing, building safe defaults")
return
for field in SCORECARD_FIELDS:
if field not in scorecard:
scorecard[field] = 0.6
logger.warning("Harness Layer 2': scorecard missing sub-field %s, auto-patched to 0.6", field)
else:
try:
scorecard[field] = max(0.0, min(1.0, float(scorecard[field])))
except (TypeError, ValueError):
logger.warning("Harness Layer 2': scorecard.%s invalid type (%r), reset to 0.5", field, scorecard[field])
scorecard[field] = 0.5
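# Illustrative repair of the shallow-merge defect: {"scorecard": {"evidence": 0.9}}
#   -> {"evidence": 0.9, "chain_completeness": 0.6, "critique_quality": 0.6,
#       "defense_quality": 0.6, "calibration": 0.6}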
def _harness_validate_verdict(output: dict[str, Any]) -> None:
"""
Harness Layer 3: verdict enum + weighted_score type safety.
FIX for stress test defect: '392 TypeError cases - str >= int comparison failure
when LLM returns string as weighted_score. Root fix: force float FIRST.'
"""
# ROOT FIX: force weighted_score to float before any comparison
raw_ws = output.get("weighted_score", 50.0)
try:
output["weighted_score"] = float(raw_ws)
except (TypeError, ValueError):
logger.warning("Harness Layer 3: weighted_score invalid (%r), forcing to 50.0", raw_ws)
output["weighted_score"] = 50.0
# Correct verdict enum
if output.get("verdict", "") not in VALID_VERDICTS:
logger.warning("Harness Layer 3: illegal verdict=%s, forcing MAINTAIN", output.get("verdict"))
output["verdict"] = "MAINTAIN"
# Recompute from scorecard (prevent LLM arithmetic errors)
scorecard = output.get("scorecard", {})
if scorecard:
recalculated = _compute_weighted_score(scorecard)
original = output["weighted_score"] # guaranteed float
if abs(recalculated - original) > 5.0:
logger.warning("Harness Layer 3: weighted_score drift (%.2f vs %.2f), using recalculated", original, recalculated)
output["weighted_score"] = recalculated
if output["verdict"] != "SKIPPED":
output["verdict"] = "MAINTAIN" if output["weighted_score"] >= 50 else "DOWNGRADE"
def _normalize_cwe_set(raw: Any) -> set[str]:
if raw is None:
return set()
if isinstance(raw, str):
return {match.upper() for match in re.findall(r"\bCWE-\d+\b", raw, flags=re.IGNORECASE)}
if isinstance(raw, dict):
values: set[str] = set()
for value in raw.values():
values.update(_normalize_cwe_set(value))
return values
if isinstance(raw, (list, tuple, set)):
values: set[str] = set()
for item in raw:
values.update(_normalize_cwe_set(item))
return values
return set()
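# Illustrative normalization (inputs are hypothetical):
#   _normalize_cwe_set("cwe-79 (XSS), CWE-89")            -> {"CWE-79", "CWE-89"}
#   _normalize_cwe_set([{"cwe_id": "CWE-22"}, "CWE-79"])  -> {"CWE-22", "CWE-79"}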
def _collect_observed_cwe_categories(analyst_data: dict[str, Any], context: dict[str, Any]) -> set[str]:
explicit = context.get("observed_cwe_categories") or analyst_data.get("observed_cwe_categories")
if explicit:
return _normalize_cwe_set(explicit)
observed: set[str] = set()
for key in (
"analysis",
"vulnerabilities",
"findings",
"code_findings",
"security_findings",
"code_patterns",
"code_patterns_summary",
):
observed.update(_normalize_cwe_set(analyst_data.get(key)))
return observed
def _build_recall_challenge(analyst_data: dict[str, Any]) -> dict[str, Any] | None:
context = (
analyst_data.get("benchmark_context")
or analyst_data.get("dim11_benchmark")
or analyst_data.get("recall_gate")
or {}
)
if not isinstance(context, dict):
context = {}
expected = _normalize_cwe_set(
context.get("expected_cwe_categories") or analyst_data.get("expected_cwe_categories")
)
if not expected:
return None
observed = _collect_observed_cwe_categories(analyst_data, context)
matched = expected & observed
missing = expected - observed
category_recall = len(matched) / len(expected)
try:
min_recall = float(context.get("min_category_recall", context.get("min_recall", 0.7)))
except (TypeError, ValueError):
min_recall = 0.7
min_recall = max(0.0, min(1.0, min_recall))
route_correct = bool(context.get("route_correct", analyst_data.get("route_correct", True)))
try:
pollution_count = int(context.get("external_pollution_count", analyst_data.get("external_pollution_count", 0)) or 0)
except (TypeError, ValueError):
pollution_count = 0
failed_reasons = []
if category_recall < min_recall:
failed_reasons.append("category_recall_below_threshold")
if not route_correct:
failed_reasons.append("route_incorrect")
if pollution_count > 0:
failed_reasons.append("external_evidence_pollution")
return {
"fixture": context.get("fixture") or analyst_data.get("fixture"),
"expected_cwe_categories": sorted(expected),
"observed_cwe_categories": sorted(observed),
"missing_cwe_categories": sorted(missing),
"category_recall": round(category_recall, 4),
"min_category_recall": min_recall,
"route_correct": route_correct,
"external_pollution_count": pollution_count,
"verdict": "FAIL" if failed_reasons else "PASS",
"failed_reasons": failed_reasons,
}
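# Illustrative gate (hypothetical fixture): expected {CWE-79, CWE-89}, observed {CWE-79},
# with defaults route_correct=True and external_pollution_count=0
#   -> category_recall 0.5 < min 0.7 -> verdict "FAIL",
#      failed_reasons ["category_recall_below_threshold"]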
def _apply_recall_challenge(output: dict[str, Any], analyst_data: dict[str, Any]) -> None:
challenge = _build_recall_challenge(analyst_data)
if not challenge:
return
output["recall_challenge"] = challenge
scorecard = output.setdefault("scorecard", {})
scorecard["category_recall"] = challenge["category_recall"]
scorecard["route_correctness"] = 1.0 if challenge["route_correct"] else 0.0
scorecard["evidence_pollution"] = 1.0 if challenge["external_pollution_count"] == 0 else 0.0
if challenge["verdict"] == "PASS":
return
message = (
"Deterministic recall challenge failed: "
f"missing={challenge['missing_cwe_categories']}, "
f"recall={challenge['category_recall']:.2f}, "
f"min={challenge['min_category_recall']:.2f}, "
f"route_correct={challenge['route_correct']}, "
f"external_pollution_count={challenge['external_pollution_count']}."
)
challenges = output.setdefault("challenges", [])
if not isinstance(challenges, list):
challenges = [str(challenges)]
output["challenges"] = challenges
challenges.append(message)
output["verdict"] = "DOWNGRADE"
try:
output["weighted_score"] = min(float(output.get("weighted_score", 50.0)), 49.0)
except (TypeError, ValueError):
output["weighted_score"] = 49.0
reasoning = output.get("reasoning") or ""
output["reasoning"] = f"{reasoning} {message}".strip()
output["needs_rescan"] = True
def run_critic_pipeline(analyst_output: str | dict[str, Any], input_type: str = "pkg") -> dict[str, Any]:
"""Execute Critic Agent Pipeline (Harness Layers 1/1b/2/2'/3).
Args:
        analyst_output: Analyst output JSON (str or dict)
        input_type: Path-Aware Skill routing key (pkg/code/injection/config)
"""
from crewai import Crew, Process
if isinstance(analyst_output, dict):
analyst_dict = analyst_output
analyst_str = json.dumps(analyst_output, ensure_ascii=False, indent=2)
else:
analyst_str = str(analyst_output) if analyst_output else ""
try:
analyst_dict = json.loads(analyst_str)
except (json.JSONDecodeError, ValueError):
analyst_dict = {}
if not ENABLE_CRITIC:
logger.info("Harness Layer 1: ENABLE_CRITIC=false, skipping")
output = _build_skipped_output()
_apply_recall_challenge(output, analyst_dict)
return output
logger.info("Critic Pipeline started (max rounds: %d)", MAX_DEBATE_ROUNDS)
    # 429 auto-rotation: retry up to MAX_LLM_RETRIES times, switching to the next model on each attempt
from config import mark_model_failed, get_current_model_name
MAX_LLM_RETRIES = 2
excluded_models: list[str] = []
output: dict[str, Any] = {}
crew_success = False
for attempt in range(MAX_LLM_RETRIES + 1):
agent = create_critic_agent(excluded_models, input_type=input_type)
task = create_critic_task(agent, analyst_str)
try:
crew = Crew(agents=[agent], tasks=[task], process=Process.sequential, verbose=True)
logger.info("Critic Crew kickoff (attempt %d/%d)", attempt + 1, MAX_LLM_RETRIES + 1)
try:
from checkpoint import recorder as _cp
_c_model = get_current_model_name(agent.llm)
_cp.llm_call("critic", _c_model, "openrouter", f"attempt={attempt+1}")
except Exception:
_c_model = "unknown"
_t_c = time.time()
result = crew.kickoff()
raw_output = str(result.raw) if hasattr(result, "raw") else str(result)
try:
_cp.llm_result("critic", _c_model, "SUCCESS",
len(raw_output), int((time.time() - _t_c) * 1000),
thinking=raw_output[:1000])
except Exception:
pass
output = _extract_json_from_output(raw_output)
crew_success = bool(output)
            break  # success: exit the retry loop
except Exception as e:
error_str = str(e)
if "429" in error_str and attempt < MAX_LLM_RETRIES:
current_model = get_current_model_name(agent.llm)
mark_model_failed(current_model)
excluded_models.append(current_model)
                    _m = re.search(r'retry.{1,10}(\d+\.?\d*)s', error_str, re.IGNORECASE)
                    retry_after = float(_m.group(1)) if _m else 0.0
logger.warning("[RETRY] Critic 429 on %s (attempt %d/%d), api_retry_after=%.0fs",
current_model, attempt + 1, MAX_LLM_RETRIES, retry_after)
try:
_cp.llm_retry("critic", current_model, error_str[:200],
attempt + 1, "next_in_waterfall")
except Exception:
pass
from config import rate_limiter as _rl
                    _rl.on_429(retry_after=retry_after, caller="critic")  # minimum 30s backoff
continue
logger.error("Critic CrewAI failed: %s", e)
try:
_cp.llm_error("critic", _c_model, error_str[:300])
except Exception:
pass
if not crew_success or not output:
logger.warning("Harness Layer 1b: LLM output invalid, using fallback")
output = _build_fallback_output(analyst_dict)
schema_errors = _harness_validate_schema(output)
if schema_errors:
logger.warning("Harness Layer 2: Schema errors %s, merging fallback", schema_errors)
fallback = _build_fallback_output(analyst_dict)
for k, v in fallback.items():
if k not in output:
output[k] = v
_harness_repair_scorecard(output) # Layer 2'
_harness_validate_verdict(output) # Layer 3
_apply_recall_challenge(output, analyst_dict)
round_no = int(analyst_dict.get("_debate_round", output.get("debate_rounds", 1)) or 1)
max_rounds = int(analyst_dict.get("_debate_max_rounds", MAX_DEBATE_ROUNDS) or MAX_DEBATE_ROUNDS)
output["debate_rounds"] = max(int(output.get("debate_rounds", 0) or 0), round_no)
output["_critic_round"] = round_no
output["_max_rounds"] = max_rounds
if "generated_at" not in output:
output["generated_at"] = datetime.now(timezone.utc).isoformat()
logger.info("Critic Pipeline done | verdict=%s | score=%.1f | challenges=%d",
output.get("verdict"), output.get("weighted_score", 0), len(output.get("challenges", [])))
return output
if __name__ == "__main__":
import sys
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(message)s")
_test = json.dumps({"analysis": [{"cve_id": "CVE-2024-42005", "original_cvss": 9.8, "chain_risk": {"is_chain": True, "confidence": "HIGH"}}]})
result = run_critic_pipeline(_test)
print(json.dumps(result, ensure_ascii=False, indent=2))