| """End-to-end integration test for ACO. |
| |
| Demonstrates the full pipeline: |
| 1. Classify task type and difficulty |
| 2. Route to appropriate model tier |
| 3. Budget context |
| 4. Decide tool usage |
| 5. Decide verification |
| 6. Handle failure/retry |
| 7. Detect doom |
| 8. Collect telemetry |
| """ |
| import json, sys, os |
|
|
| |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
| from aco.classifier import TaskCostClassifier |
| from aco.context_budgeter import ContextBudgeter |
| from aco.tool_gate import ToolUseGate |
| from aco.verifier_budgeter import VerifierBudgeter |
| from aco.retry_optimizer import RetryOptimizer |
| from aco.doom_detector import DoomDetector |
| from aco.telemetry import TelemetryCollector |
| from aco.cache_layout import CacheAwareLayout |
| from aco.meta_tool_miner import MetaToolMiner |
| from aco.trace_schema import AgentTrace, TraceStep, ModelCall, ToolCall |
| from aco.conformal import ConformalEscalationCalibrator |
| from aco.pareto import build_frontier_report |
|
|
|
|
def test_full_pipeline():
    """Run the full ACO pipeline on sample tasks."""
    print("=" * 60)
    print("ACO END-TO-END INTEGRATION TEST")
    print("=" * 60)

    # Instantiate every ACO component once; they are shared across tasks.
    classifier = TaskCostClassifier()
    budgeter = ContextBudgeter()
    tool_gate = ToolUseGate()
    verifier = VerifierBudgeter()
    retry = RetryOptimizer()
    doom = DoomDetector()
    telemetry = TelemetryCollector()
    cache = CacheAwareLayout()
    miner = MetaToolMiner()
    conformal = ConformalEscalationCalibrator(alpha=0.05)

    # Sample tasks spanning distinct types, difficulties, and risk levels.
    tasks = [
        "Fix the typo in the README",
        "Implement user authentication with JWT tokens",
        "Research the latest advances in transformer architectures",
        "Review this contract for GDPR compliance issues",
        "Deploy the new microservice to production",
        "Write a Python function to sort a list",
    ]

    results = []

    for task in tasks:
        print(f"\n--- Task: {task[:50]}... ---")

        # Step 1: classify task type, difficulty, risk, and estimated cost.
        cls = classifier.classify(task)
        print(f" Type: {cls['task_type']}, Difficulty: {cls['difficulty']}")
        print(f" Risk: {cls['risk_level']}, Needs retrieval: {cls['needs_retrieval']}")
        print(f" Needs verifier: {cls['needs_verifier']}, Est cost: ${cls['estimated_cost']:.2f}")

        # Step 2: route to the recommended model tier.
        tier = cls['recommended_tier']
        print(f" Routed to tier: {tier}")

        # Conformal escalation: if the calibrated success probability at this
        # tier is too low, escalate one tier (capped at 5). The probability
        # here is a simulated stand-in, not a real model prediction.
        psuccess = 0.7 + 0.05 * tier
        should_esc = conformal.should_escalate(tier, psuccess)
        if should_esc:
            tier = min(tier + 1, 5)
            print(f" Conformal escalation → tier {tier}")

        # Step 3: budget the context window for this task and tier.
        ctx = budgeter.budget(task, cls['task_type'], tier)
        print(f" Context: {ctx.total_tokens} tokens ({len(ctx.keep)} keep, {len(ctx.summarize)} summarize)")

        # Lay out the prompt so the stable prefix can be KV-cached.
        prompt = cache.layout(task, ctx)
        print(f" Cache: {prompt.prefix_tokens} prefix (cacheable), {prompt.suffix_tokens} suffix (dynamic)")

        # Step 4: decide which tools to expose and whether to batch calls.
        tools = tool_gate.evaluate(task, cls['task_type'], tier)
        print(f" Tools: {len(tools.recommended)} recommended, {len(tools.skipped)} skipped, batch={tools.should_batch}")

        # Step 5: decide whether the output is worth a verification pass.
        v = verifier.should_verify(task, cls['task_type'], tier, psuccess)
        print(f" Verifier: {v.should_verify} (risk={v.risk_level}, reason={v.reason})")

        # Step 8: record a trace step so telemetry can aggregate the run.
        step = TraceStep(
            step_type="route",
            model_call=ModelCall(tier=tier, psuccess=psuccess),
            tool_calls=[ToolCall(name=t) for t in tools.recommended[:3]],
        )
        telemetry.record(task, cls, step)

        results.append({
            "task": task,
            "type": cls['task_type'],
            "difficulty": cls['difficulty'],
            "tier": tier,
            "context_tokens": ctx.total_tokens,
            "cache_prefix": prompt.prefix_tokens,
            "tools_recommended": len(tools.recommended),
            "verifier_needed": v.should_verify,
        })
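
        # Hedged per-task sanity checks. Assumptions (not guaranteed by the
        # APIs above): tiers fall in 0-5 given the escalation cap, and token
        # counts are non-negative.
        assert 0 <= tier <= 5, f"tier out of expected range: {tier}"
        assert ctx.total_tokens >= 0 and prompt.prefix_tokens >= 0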

    # Step 7: doom detection on a clearly struggling run state.
    print("\n--- Doom Detector Test ---")
    state = RunState(
        failed_tool_calls=4,
        total_cost=2.50,
        artifacts_created=0,
        retries=3,
        context_tokens=80000,
    )
    doom_result = doom.check(state)
    print(f" State: {state.failed_tool_calls} failed tools, ${state.total_cost:.2f} cost, {state.artifacts_created} artifacts")
    print(f" Doom: {doom_result.is_doomed}, Score: {doom_result.doom_score:.2f}")
    print(f" Action: {doom_result.recommended_action}")

    # Step 6: failure handling and retry strategy recommendations.
    print("\n--- Retry Optimizer Test ---")
    recovery = retry.recommend("tool_call_failed", "SyntaxError in generated code")
    print(" Failure: tool_call_failed")
    print(f" Recovery: {recovery.action} (strategy={recovery.strategy})")

    recovery2 = retry.recommend("model_overloaded", "Rate limit exceeded")
    print(" Failure: model_overloaded")
    print(f" Recovery: {recovery2.action}")

    # Mine a frequent trace pattern into a single macro tool.
    print("\n--- Meta-Tool Miner Test ---")
    pattern = TracePattern(
        name="search_read_patch_test",
        steps=["search", "read_file", "edit_file", "run_test"],
        frequency=47,
        success_rate=0.82,
    )
    macro = miner.compress(pattern)
    print(f" Pattern: {pattern.name} (freq={pattern.frequency}, sr={pattern.success_rate})")
    print(f" Macro: {macro.name}, steps={macro.steps}, savings={macro.estimated_savings:.1f} LLM calls")

    # Build a cost/quality Pareto frontier over candidate routing policies,
    # given as (quality, cost-per-task) pairs.
    print("\n--- Pareto Frontier Test ---")
    policies = {
        "oracle": (0.870, 0.062),
        "v10_feedback": (0.848, 0.201),
        "frontier": (0.782, 0.317),
        "v10_direct": (0.766, 0.188),
        "always_cheap": (0.632, 0.014),
    }
    report = build_frontier_report(policies)
    print(f" AIQ: {report['aiq']}")
    for level, cost in report['frontier_quality_levels'].items():
        print(f" {level}: ${cost:.3f}")

    # Final summary across all tasks.
    print("\n" + "=" * 60)
    print("INTEGRATION TEST COMPLETE")
    print("=" * 60)
    print(f"Processed {len(results)} tasks")
    tier_dist = {}
    for r in results:
        tier_dist[r['tier']] = tier_dist.get(r['tier'], 0) + 1
    print(f"Tier distribution: {tier_dist}")
    verified = sum(1 for r in results if r['verifier_needed'])
    print(f"Verifier calls: {verified}/{len(results)} ({verified / len(results) * 100:.0f}%)")
    avg_ctx = sum(r['context_tokens'] for r in results) / len(results)
    print(f"Average context: {avg_ctx:.0f} tokens")

    return results
|
|
|
|
if __name__ == "__main__":
    test_full_pipeline()
|
|