Sprint 8: quorum.py - consensus/disagreement topology switches + critic ensemble
361d29c verified
"""
quorum.py - Consensus- and disagreement-driven topology switches.
When multiple agents produce outputs:
- Agreement → merge outputs confidently
- Disagreement → escalate to critic ensemble or HITL
- Critical risk → require human approval
Critic ensemble personas:
- Correctness critic
- Safety/security critic
- Cost/latency critic
- User-alignment critic
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from typing import Any
from purpose_agent.llm_backend import LLMBackend, ChatMessage
logger = logging.getLogger(__name__)
class QuorumDecision:
MERGE = "merge"
ESCALATE = "escalate"
HITL = "hitl"
REJECT = "reject"
@dataclass
class QuorumConfig:
agreement_threshold: float = 0.7
disagreement_threshold: float = 0.4
    critical_risk_keywords: list[str] = field(
        default_factory=lambda: ["delete", "drop", "remove", "destroy", "sudo", "admin"]
    )
min_votes: int = 2
@dataclass
class CriticVerdict:
critic_name: str
score: float
reasoning: str
flags: list[str] = field(default_factory=list)
class QuorumCoordinator:
"""
Decides topology based on agent agreement/disagreement.
Usage:
qc = QuorumCoordinator(config=QuorumConfig())
decision = qc.evaluate(outputs=["answer_a", "answer_b", "answer_c"])
if decision == QuorumDecision.MERGE: ...
elif decision == QuorumDecision.ESCALATE: ...
"""
def __init__(self, config: QuorumConfig | None = None):
self.config = config or QuorumConfig()
    def evaluate(self, outputs: list[str], task: str = "") -> str:
        # Critical-risk keywords force human review, even for a single output,
        # so check them before the vote-count shortcut.
        combined = " ".join(outputs).lower()
        if any(kw in combined for kw in self.config.critical_risk_keywords):
            return QuorumDecision.HITL
        # Too few votes to compare: merge by default.
        if len(outputs) < self.config.min_votes:
            return QuorumDecision.MERGE
        # Route on word-overlap agreement between the outputs.
        agreement = self._measure_agreement(outputs)
        if agreement >= self.config.agreement_threshold:
            return QuorumDecision.MERGE
        if agreement <= self.config.disagreement_threshold:
            return QuorumDecision.ESCALATE
        return QuorumDecision.MERGE
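    # Illustrative routing (assumed inputs, not from the test suite): a single
    # output "sudo rm -rf /tmp/cache" contains "sudo" and returns HITL even
    # though min_votes is not met; two benign but unrelated answers share
    # almost no words, fall below disagreement_threshold, and ESCALATE.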
    def _measure_agreement(self, outputs: list[str]) -> float:
        if len(outputs) <= 1:
            return 1.0
        # Jaccard-style word overlap: |intersection of all| / |union of all|.
        word_sets = [set(o.lower().split()) for o in outputs]
        common = set.intersection(*word_sets)
        total = set.union(*word_sets)
        return len(common) / max(len(total), 1)
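    # Worked example (illustrative values): for ["the answer is 42",
    # "the answer is 42", "maybe the answer is 42"], the intersection is
    # {the, answer, is, 42} (4 words) and the union adds "maybe" (5 words),
    # so agreement = 4/5 = 0.8 >= 0.7 and the coordinator merges.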
class CriticEnsemble:
"""
Ensemble of specialized critics for multi-perspective evaluation.
Usage:
ensemble = CriticEnsemble(llm=backend)
verdicts = ensemble.evaluate(output="agent's response", task="original task")
avg_score = ensemble.aggregate(verdicts)
"""
CRITICS = [
("correctness", "Is the output factually correct and complete?"),
("safety", "Does the output contain unsafe, harmful, or policy-violating content?"),
("efficiency", "Is the output concise and cost-effective?"),
("alignment", "Does the output align with the user's stated purpose?"),
]
def __init__(self, llm: LLMBackend | None = None):
self.llm = llm
self._history: list[list[CriticVerdict]] = []
def evaluate(self, output: str, task: str = "") -> list[CriticVerdict]:
verdicts = []
for name, question in self.CRITICS:
if self.llm:
verdict = self._llm_evaluate(name, question, output, task)
else:
verdict = CriticVerdict(critic_name=name, score=0.5, reasoning="No LLM available")
verdicts.append(verdict)
self._history.append(verdicts)
return verdicts
    def aggregate(self, verdicts: list[CriticVerdict]) -> float:
        if not verdicts:
            return 0.0
        return sum(v.score for v in verdicts) / len(verdicts)
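    # Example (illustrative): verdicts scored 0.9, 0.5, 0.7, 0.9 aggregate to
    # 0.75. This is an unweighted mean, so one harsh critic drags the
    # ensemble score down noticeably.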
    def _llm_evaluate(self, name: str, question: str, output: str, task: str) -> CriticVerdict:
        prompt = (
            f"Task: {task}\nOutput: {output[:500]}\n\n"
            f"Question: {question}\nScore 0-10 and explain briefly."
        )
        try:
            # Local import keeps the parser dependency optional at module load.
            from purpose_agent.robust_parser import extract_number
            raw = self.llm.generate([ChatMessage(role="user", content=prompt)], temperature=0.2, max_tokens=300)
            score = extract_number(raw, "score", 5.0) / 10.0
            return CriticVerdict(critic_name=name, score=min(1.0, max(0.0, score)), reasoning=raw[:200])
        except Exception as exc:  # one failing critic must not take down the ensemble
            logger.warning("Critic %r evaluation failed: %s", name, exc)
            return CriticVerdict(critic_name=name, score=0.5, reasoning="Evaluation failed")
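if __name__ == "__main__":
    # Illustrative demo, not a test: exercises both classes with assumed inputs
    # and no real LLMBackend, so every critic falls back to the neutral 0.5 verdict.
    logging.basicConfig(level=logging.INFO)
    qc = QuorumCoordinator()
    print(qc.evaluate(["the answer is 42", "the answer is 42", "maybe the answer is 42"]))  # merge
    print(qc.evaluate(["use sudo to fix it", "reboot the host"]))  # hitl (keyword "sudo")
    ensemble = CriticEnsemble(llm=None)
    verdicts = ensemble.evaluate(output="draft response", task="summarize the incident")
    print([f"{v.critic_name}={v.score}" for v in verdicts], ensemble.aggregate(verdicts))  # 0.5 across the board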