narcolepticchicken committed on
Commit 5c39bf9 · verified · 1 Parent(s): 0087553

Upload tests/test_integration.py

Files changed (1)
  1. tests/test_integration.py +187 -0
tests/test_integration.py ADDED
@@ -0,0 +1,187 @@
+"""End-to-end integration test for ACO.
+
+Demonstrates the full pipeline:
+1. Classify task type and difficulty
+2. Route to appropriate model tier
+3. Budget context
+4. Decide tool usage
+5. Decide verification
+6. Handle failure/retry
+7. Detect doom
+8. Collect telemetry
+"""
+import json, sys, os
+
+# Add parent to path
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from aco.classifier import TaskCostClassifier
+from aco.context_budgeter import ContextBudgeter
+from aco.tool_gate import ToolUseGate
+from aco.verifier_budgeter import VerifierBudgeter
+from aco.retry_optimizer import RetryOptimizer
+from aco.doom_detector import DoomDetector
+from aco.telemetry import TelemetryCollector
+from aco.cache_layout import CacheAwareLayout
+from aco.meta_tool_miner import MetaToolMiner
+from aco.trace_schema import AgentTrace, TraceStep, ModelCall, ToolCall
+from aco.conformal import ConformalEscalationCalibrator
+from aco.pareto import build_frontier_report
+
+
+def test_full_pipeline():
+    """Run the full ACO pipeline on sample tasks."""
+    print("="*60)
+    print("ACO END-TO-END INTEGRATION TEST")
+    print("="*60)
+
+    # Initialize all modules
+    classifier = TaskCostClassifier()
+    budgeter = ContextBudgeter()
+    tool_gate = ToolUseGate()
+    verifier = VerifierBudgeter()
+    retry = RetryOptimizer()
+    doom = DoomDetector()
+    telemetry = TelemetryCollector()
+    cache = CacheAwareLayout()
+    miner = MetaToolMiner()
+    conformal = ConformalEscalationCalibrator(alpha=0.05)
+
+    # Sample tasks
+    tasks = [
+        "Fix the typo in the README",
+        "Implement user authentication with JWT tokens",
+        "Research the latest advances in transformer architectures",
+        "Review this contract for GDPR compliance issues",
+        "Deploy the new microservice to production",
+        "Write a Python function to sort a list",
+    ]
+
+    results = []
+
+    for task in tasks:
+        print(f"\n--- Task: {task[:50]}... ---")
+
+        # 1. Classify
+        cls = classifier.classify(task)
+        print(f" Type: {cls['task_type']}, Difficulty: {cls['difficulty']}")
+        print(f" Risk: {cls['risk_level']}, Needs retrieval: {cls['needs_retrieval']}")
+        print(f" Needs verifier: {cls['needs_verifier']}, Est cost: ${cls['estimated_cost']:.2f}")
+
+        # 2. Route (using classifier's recommended tier)
+        tier = cls['recommended_tier']
+        print(f" Routed to tier: {tier}")
+
+        # 3. Conformal check
+        psuccess = 0.7 + 0.05 * tier  # placeholder
+        should_esc = conformal.should_escalate(tier, psuccess)
+        if should_esc:
+            tier = min(tier + 1, 5)
+            print(f" Conformal escalation → tier {tier}")
+
+        # 4. Budget context
+        ctx = budgeter.budget(task, cls['task_type'], tier)
+        print(f" Context: {ctx.total_tokens} tokens ({len(ctx.keep)} keep, {len(ctx.summarize)} summarize)")
+
+        # 5. Cache layout
+        prompt = cache.layout(task, ctx)
+        print(f" Cache: {prompt.prefix_tokens} prefix (cacheable), {prompt.suffix_tokens} suffix (dynamic)")
+
+        # 6. Tool gate
+        tools = tool_gate.evaluate(task, cls['task_type'], tier)
+        print(f" Tools: {len(tools.recommended)} recommended, {len(tools.skipped)} skipped, batch={tools.should_batch}")
+
+        # 7. Verifier
+        v = verifier.should_verify(task, cls['task_type'], tier, psuccess)
+        print(f" Verifier: {v.should_verify} (risk={v.risk_level}, reason={v.reason})")
+
+        # 8. Collect telemetry
+        step = TraceStep(
+            step_type="route",
+            model_call=ModelCall(tier=tier, psuccess=psuccess),
+            tool_calls=[ToolCall(name=t) for t in tools.recommended[:3]],
+        )
+        telemetry.record(task, cls, step)
+
+        results.append({
+            "task": task,
+            "type": cls['task_type'],
+            "difficulty": cls['difficulty'],
+            "tier": tier,
+            "context_tokens": ctx.total_tokens,
+            "cache_prefix": prompt.prefix_tokens,
+            "tools_recommended": len(tools.recommended),
+            "verifier_needed": v.should_verify,
+        })
+
+    # 9. Test doom detector
+    print(f"\n--- Doom Detector Test ---")
+    from aco.doom_detector import RunState
+    state = RunState(
+        failed_tool_calls=4,
+        total_cost=2.50,
+        artifacts_created=0,
+        retries=3,
+        context_tokens=80000,
+    )
+    doom_result = doom.check(state)
+    print(f" State: {state.failed_tool_calls} failed tools, {state.total_cost} cost, {state.artifacts_created} artifacts")
+    print(f" Doom: {doom_result.is_doomed}, Score: {doom_result.doom_score:.2f}")
+    print(f" Action: {doom_result.recommended_action}")
+
+    # 10. Test retry optimizer
+    print(f"\n--- Retry Optimizer Test ---")
+    recovery = retry.recommend("tool_call_failed", "SyntaxError in generated code")
+    print(f" Failure: tool_call_failed")
+    print(f" Recovery: {recovery.action} (strategy={recovery.strategy})")
+
+    recovery2 = retry.recommend("model_overloaded", "Rate limit exceeded")
+    print(f" Failure: model_overloaded")
+    print(f" Recovery: {recovery2.action}")
+
+    # 11. Test meta-tool miner
+    print(f"\n--- Meta-Tool Miner Test ---")
+    from aco.meta_tool_miner import TracePattern
+    pattern = TracePattern(
+        name="search_read_patch_test",
+        steps=["search", "read_file", "edit_file", "run_test"],
+        frequency=47,
+        success_rate=0.82,
+    )
+    macro = miner.compress(pattern)
+    print(f" Pattern: {pattern.name} (freq={pattern.frequency}, sr={pattern.success_rate})")
+    print(f" Macro: {macro.name}, steps={macro.steps}, savings={macro.estimated_savings:.1f} LLM calls")
+
+    # 12. Test Pareto frontier
+    print(f"\n--- Pareto Frontier Test ---")
+    policies = {
+        "oracle": (0.870, 0.062),
+        "v10_feedback": (0.848, 0.201),
+        "frontier": (0.782, 0.317),
+        "v10_direct": (0.766, 0.188),
+        "always_cheap": (0.632, 0.014),
+    }
+    report = build_frontier_report(policies)
+    print(f" AIQ: {report['aiq']}")
+    for level, cost in report['frontier_quality_levels'].items():
+        print(f" {level}: ${cost:.3f}")
+
+    # Summary
+    print(f"\n{'='*60}")
+    print("INTEGRATION TEST COMPLETE")
+    print(f"{'='*60}")
+    print(f"Processed {len(results)} tasks")
+    tier_dist = {}
+    for r in results:
+        tier_dist[r['tier']] = tier_dist.get(r['tier'], 0) + 1
+    print(f"Tier distribution: {tier_dist}")
+    verified = sum(1 for r in results if r['verifier_needed'])
+    print(f"Verifier calls: {verified}/{len(results)} ({verified/len(results)*100:.0f}%)")
+    avg_ctx = sum(r['context_tokens'] for r in results) / len(results)
+    print(f"Average context: {avg_ctx:.0f} tokens")
+
+    return results
+
+
+if __name__ == "__main__":
+    test_full_pipeline()
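
Run note: because the script prepends the repository root to sys.path and guards its entry point with if __name__ == "__main__", it can be executed directly as "python tests/test_integration.py"; if pytest is installed it will also be collected automatically (the file and function follow the test_* naming convention), and the -s flag is needed to see the printed progress.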