narcolepticchicken committed on
Commit
d4a573d
·
verified ·
1 Parent(s): 8275b23

Upload train_router_gen.py

Browse files
Files changed (1) hide show
  1. train_router_gen.py +126 -0
train_router_gen.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # ─── Synthetic Trace Generator (for training data) ────────────────
3
# Per-model settings keyed by model name. "tier" is the routing tier (1-5).
# NOTE(review): cost_* look like prices per 1K tokens and latency like
# milliseconds -- confirm against whoever consumes these values.
MODEL_CONFIGS = {
    "tiny_local": {
        "tier": 1,
        "cost_input": 0.0001,
        "cost_output": 0.0002,
        "latency": 200,
        "strength": 0.35,
    },
    "cheap_cloud": {
        "tier": 2,
        "cost_input": 0.00015,
        "cost_output": 0.0006,
        "latency": 400,
        "strength": 0.55,
    },
    "medium": {
        "tier": 3,
        "cost_input": 0.0015,
        "cost_output": 0.006,
        "latency": 800,
        "strength": 0.80,
    },
    "frontier": {
        "tier": 4,
        "cost_input": 0.005,
        "cost_output": 0.015,
        "latency": 1500,
        "strength": 0.93,
    },
    "specialist": {
        "tier": 5,
        "cost_input": 0.01,
        "cost_output": 0.03,
        "latency": 2000,
        "strength": 0.97,
    },
}

# Reverse lookup: tier number -> model name (derived from the table above,
# so the two can never drift apart).
TIER_TO_MODEL = {
    config["tier"]: model_name for model_name, config in MODEL_CONFIGS.items()
}

# Extra per-tier multiplier applied on top of the model's own input cost.
TIER_COST_MULT = {
    1: 0.05,
    2: 0.15,
    3: 0.75,
    4: 1.0,
    5: 1.5,
}
12
+
13
# Example user requests grouped by task type. The key order matters to any
# caller that draws a task type from list(TASK_TEMPLATES) with a seeded RNG.
TASK_TEMPLATES = {
    "quick_answer": [
        "What is the capital of France?",
        "Explain quantum computing briefly.",
        "What is 237*452?",
        "Define photosynthesis.",
        "Who wrote Hamlet?",
        "What is the speed of light?",
        "List the primary colors.",
        "What is GDP?",
        "When was the Declaration of Independence signed?",
    ],
    "coding": [
        "Write a Python function to reverse a linked list.",
        "Fix the bug in this React component.",
        "Refactor auth module to JWT.",
        "Implement LRU cache in Go.",
        "Debug segfault in C++ thread pool.",
        "Add unit tests for the payment module.",
        "Optimize this SQL query.",
        "Create a REST API for user management.",
        "Implement binary search in Rust.",
        "Write a recursive descent parser for JSON.",
    ],
    "research": [
        "Research latest transformer advances.",
        "Find sources comparing LoRA and full FT.",
        "Investigate data center climate impact.",
        "What does literature say on speculative decoding?",
        "Survey privacy-preserving ML techniques.",
        "Compare reinforcement learning algorithms for robotics.",
        "Analyze trends in LLM scaling laws.",
    ],
    "document_drafting": [
        "Draft project proposal for ML pipeline.",
        "Write email to team about deployment.",
        "Create technical report on performance.",
        "Write a project brief for the migration.",
        "Draft meeting notes summary.",
    ],
    "legal_regulated": [
        "Review this contract for liability clauses.",
        "Check GDPR compliance for data pipeline.",
        "Draft privacy policy section.",
        "Analyze indemnification clause in vendor agreement.",
        "Verify regulatory compliance for medical device software.",
    ],
    "tool_heavy": [
        "Search open issues and create summary.",
        "Fetch API docs and generate client code.",
        "Query Q3 sales and produce chart.",
        "Aggregate metrics from 5 monitoring endpoints.",
    ],
    "retrieval_heavy": [
        "Answer based on 50-page document.",
        "Find all payment processing mentions.",
        "Retrieve relevant cases for legal query.",
        "Summarize the quarterly earnings report.",
    ],
    "long_horizon": [
        "Plan 3-month roadmap.",
        "Orchestrate multi-region deployment.",
        "Redesign data architecture end-to-end.",
        "Migrate monolith to microservices.",
    ],
    "unknown_ambiguous": [
        "Help me with this thing.",
        "I need something about the server.",
        "Can you look into that issue?",
        "There's a problem with the data.",
    ],
}
72
+
73
def tier_success_prob(tier, difficulty):
    """Return the modelled probability that ``tier`` solves a task.

    The tier's base strength is raised to the power ``difficulty * 0.6``, so
    the probability falls as difficulty grows and equals 1.0 at difficulty 0.

    Args:
        tier: Model tier number; tiers outside 1-5 use a strength of 0.5.
        difficulty: Numeric task difficulty.

    Returns:
        ``strength ** (difficulty * 0.6)`` as a float.
    """
    strength_by_tier = {1: 0.35, 2: 0.55, 3: 0.80, 4: 0.93, 5: 0.97}
    if tier in strength_by_tier:
        base = strength_by_tier[tier]
    else:
        base = 0.5  # fallback strength for an unrecognised tier
    exponent = difficulty * 0.6
    return base ** exponent
76
+
77
def generate_training_trace(idx, rng):
    """Build one synthetic routing trace for router training.

    A task type is drawn at random, every tier 1-5 is given a simulated
    pass/fail outcome for it, and a single "actual" tier is then sampled with
    weights that depend on the task's difficulty.

    Args:
        idx: Value embedded in the trace id (``train_{idx}``).
        rng: Random source exposing ``choice``, ``choices``, ``random`` and
            ``randint`` (presumably a ``random.Random`` -- confirm at caller).

    Returns:
        A dict with the trace id, request text, task type, difficulty, the
        sampled tier, the cheapest succeeding tier, the outcome string, the
        cost, the per-tier outcomes keyed by string tier, and a metadata dict
        repeating difficulty / optimal_tier / actual_tier.
    """
    difficulty_by_task = {
        "quick_answer": 1,
        "document_drafting": 2,
        "tool_heavy": 2,
        "retrieval_heavy": 2,
        "research": 3,
        "coding": 3,
        "unknown_ambiguous": 3,
        "long_horizon": 4,
        "legal_regulated": 5,
    }
    all_tiers = [1, 2, 3, 4, 5]

    task_type = rng.choice(list(TASK_TEMPLATES.keys()))
    difficulty = difficulty_by_task[task_type]

    # Simulate every tier on this task to obtain ground-truth outcomes.
    tier_outcomes = {
        tier: rng.random() < tier_success_prob(tier, difficulty)
        for tier in all_tiers
    }

    # Cheapest tier that succeeded; falls back to 5 when none did.
    optimal_tier = next((tier for tier in all_tiers if tier_outcomes[tier]), 5)

    # This uniform draw is always superseded by the weighted draw below; it is
    # kept (result discarded) so the random stream matches for a given seed.
    rng.choice(list(all_tiers))

    # Bias the sampled tier toward one that is reasonable for the difficulty.
    if difficulty <= 2:
        tier_weights = [3, 4, 2, 1, 0.5]
    elif difficulty == 3:
        tier_weights = [1, 2, 4, 2, 1]
    elif difficulty == 4:
        tier_weights = [0.5, 1, 2, 4, 2]
    else:
        tier_weights = [0.2, 0.5, 1, 3, 4]
    actual_tier = rng.choices(all_tiers, weights=tier_weights)[0]

    if tier_outcomes[actual_tier]:
        outcome = "success"
    else:
        outcome = "failure"

    user_request = rng.choice(TASK_TEMPLATES[task_type])

    # Two independent token draws are summed, then priced at the sampled
    # model's input rate and scaled by the tier multiplier.
    first_draw = rng.randint(800, 12000)
    second_draw = rng.randint(200, 6000)
    base_tokens = first_draw + second_draw
    input_rate = MODEL_CONFIGS[TIER_TO_MODEL[actual_tier]]["cost_input"]
    cost = base_tokens / 1000 * input_rate * TIER_COST_MULT[actual_tier]

    metadata = {
        "difficulty": difficulty,
        "optimal_tier": optimal_tier,
        "actual_tier": actual_tier,
    }
    return {
        "trace_id": f"train_{idx}",
        "user_request": user_request,
        "task_type": task_type,
        "difficulty": difficulty,
        "actual_tier": actual_tier,
        "optimal_tier": optimal_tier,
        "outcome": outcome,
        "cost": cost,
        "tier_outcomes": {str(tier): ok for tier, ok in tier_outcomes.items()},
        "metadata": metadata,
    }