"""End-to-end integration test for ACO.
Demonstrates the full pipeline:
1. Classify task type and difficulty
2. Route to appropriate model tier
3. Budget context
4. Decide tool usage
5. Decide verification
6. Handle failure/retry
7. Detect doom
8. Collect telemetry
"""
import os
import sys
# Add parent to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
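# (The shim assumes this test file sits one directory below the repo root,
# next to the aco package; an installed-package setup could drop it.)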
from aco.classifier import TaskCostClassifier
from aco.context_budgeter import ContextBudgeter
from aco.tool_gate import ToolUseGate
from aco.verifier_budgeter import VerifierBudgeter
from aco.retry_optimizer import RetryOptimizer
from aco.doom_detector import DoomDetector, RunState
from aco.telemetry import TelemetryCollector
from aco.cache_layout import CacheAwareLayout
from aco.meta_tool_miner import MetaToolMiner, TracePattern
from aco.trace_schema import AgentTrace, TraceStep, ModelCall, ToolCall
from aco.conformal import ConformalEscalationCalibrator
from aco.pareto import build_frontier_report
def test_full_pipeline():
"""Run the full ACO pipeline on sample tasks."""
print("="*60)
print("ACO END-TO-END INTEGRATION TEST")
print("="*60)
# Initialize all modules
classifier = TaskCostClassifier()
budgeter = ContextBudgeter()
tool_gate = ToolUseGate()
verifier = VerifierBudgeter()
retry = RetryOptimizer()
doom = DoomDetector()
telemetry = TelemetryCollector()
cache = CacheAwareLayout()
miner = MetaToolMiner()
conformal = ConformalEscalationCalibrator(alpha=0.05)
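    # alpha=0.05 asks the calibrator for roughly 95% coverage on escalation
    # decisions (our reading of the ConformalEscalationCalibrator contract).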
    # Sample tasks
    tasks = [
        "Fix the typo in the README",
        "Implement user authentication with JWT tokens",
        "Research the latest advances in transformer architectures",
        "Review this contract for GDPR compliance issues",
        "Deploy the new microservice to production",
        "Write a Python function to sort a list",
    ]
    results = []
    for task in tasks:
        print(f"\n--- Task: {task[:50]}... ---")

        # 1. Classify
        cls = classifier.classify(task)
        print(f" Type: {cls['task_type']}, Difficulty: {cls['difficulty']}")
        print(f" Risk: {cls['risk_level']}, Needs retrieval: {cls['needs_retrieval']}")
        print(f" Needs verifier: {cls['needs_verifier']}, Est cost: ${cls['estimated_cost']:.2f}")

        # 2. Route (using classifier's recommended tier)
        tier = cls['recommended_tier']
        print(f" Routed to tier: {tier}")

        # 3. Conformal check
        psuccess = 0.7 + 0.05 * tier  # placeholder
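        # The linear placeholder above stands in for a calibrated success
        # estimate; a real run would source this from the classifier or
        # historical telemetry rather than a fixed formula.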
        should_esc = conformal.should_escalate(tier, psuccess)
        if should_esc:
            tier = min(tier + 1, 5)
            print(f" Conformal escalation → tier {tier}")

        # 4. Budget context
        ctx = budgeter.budget(task, cls['task_type'], tier)
        print(f" Context: {ctx.total_tokens} tokens ({len(ctx.keep)} keep, {len(ctx.summarize)} summarize)")

        # 5. Cache layout
        prompt = cache.layout(task, ctx)
        print(f" Cache: {prompt.prefix_tokens} prefix (cacheable), {prompt.suffix_tokens} suffix (dynamic)")
        # 6. Tool gate
        tools = tool_gate.evaluate(task, cls['task_type'], tier)
        print(f" Tools: {len(tools.recommended)} recommended, {len(tools.skipped)} skipped, batch={tools.should_batch}")

        # 7. Verifier
        v = verifier.should_verify(task, cls['task_type'], tier, psuccess)
        print(f" Verifier: {v.should_verify} (risk={v.risk_level}, reason={v.reason})")

        # 8. Collect telemetry
        step = TraceStep(
            step_type="route",
            model_call=ModelCall(tier=tier, psuccess=psuccess),
            tool_calls=[ToolCall(name=t) for t in tools.recommended[:3]],
        )
        telemetry.record(task, cls, step)

        results.append({
            "task": task,
            "type": cls['task_type'],
            "difficulty": cls['difficulty'],
            "tier": tier,
            "context_tokens": ctx.total_tokens,
            "cache_prefix": prompt.prefix_tokens,
            "tools_recommended": len(tools.recommended),
            "verifier_needed": v.should_verify,
        })
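        # Hedged per-task sanity checks. The 0..5 tier range mirrors the
        # escalation cap above and is an assumption about the classifier;
        # a positive token budget is assumed for any non-empty task.
        assert 0 <= tier <= 5, f"unexpected tier {tier}"
        assert ctx.total_tokens > 0, "context budget should be positive"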
    # 9. Test doom detector
    print("\n--- Doom Detector Test ---")
    state = RunState(
        failed_tool_calls=4,
        total_cost=2.50,
        artifacts_created=0,
        retries=3,
        context_tokens=80000,
    )
    doom_result = doom.check(state)
    print(f" State: {state.failed_tool_calls} failed tools, {state.total_cost} cost, {state.artifacts_created} artifacts")
    print(f" Doom: {doom_result.is_doomed}, Score: {doom_result.doom_score:.2f}")
    print(f" Action: {doom_result.recommended_action}")
    # 10. Test retry optimizer
    print("\n--- Retry Optimizer Test ---")
    recovery = retry.recommend("tool_call_failed", "SyntaxError in generated code")
    print(" Failure: tool_call_failed")
    print(f" Recovery: {recovery.action} (strategy={recovery.strategy})")
    recovery2 = retry.recommend("model_overloaded", "Rate limit exceeded")
    print(" Failure: model_overloaded")
    print(f" Recovery: {recovery2.action}")
    # 11. Test meta-tool miner
    print("\n--- Meta-Tool Miner Test ---")
    pattern = TracePattern(
        name="search_read_patch_test",
        steps=["search", "read_file", "edit_file", "run_test"],
        frequency=47,
        success_rate=0.82,
    )
    macro = miner.compress(pattern)
    print(f" Pattern: {pattern.name} (freq={pattern.frequency}, sr={pattern.success_rate})")
    print(f" Macro: {macro.name}, steps={macro.steps}, savings={macro.estimated_savings:.1f} LLM calls")
    # 12. Test Pareto frontier
    print("\n--- Pareto Frontier Test ---")
    policies = {
        "oracle": (0.870, 0.062),
        "v10_feedback": (0.848, 0.201),
        "frontier": (0.782, 0.317),
        "v10_direct": (0.766, 0.188),
        "always_cheap": (0.632, 0.014),
    }
    report = build_frontier_report(policies)
    print(f" AIQ: {report['aiq']}")
    for level, cost in report['frontier_quality_levels'].items():
        print(f" {level}: ${cost:.3f}")
    # Summary
    print("\n" + "=" * 60)
    print("INTEGRATION TEST COMPLETE")
    print("=" * 60)
    print(f"Processed {len(results)} tasks")
    tier_dist = {}
    for r in results:
        tier_dist[r['tier']] = tier_dist.get(r['tier'], 0) + 1
    print(f"Tier distribution: {tier_dist}")
    verified = sum(1 for r in results if r['verifier_needed'])
    print(f"Verifier calls: {verified}/{len(results)} ({verified/len(results)*100:.0f}%)")
    avg_ctx = sum(r['context_tokens'] for r in results) / len(results)
    print(f"Average context: {avg_ctx:.0f} tokens")
    return results
if __name__ == "__main__":
    test_full_pipeline()