Upload train_router_gen.py
Browse files- train_router_gen.py +126 -0
train_router_gen.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# βββ Synthetic Trace Generator (for training data) ββββββββββββββββ
|
| 3 |
+
MODEL_CONFIGS = {
|
| 4 |
+
"tiny_local": {"tier":1,"cost_input":0.0001,"cost_output":0.0002,"latency":200,"strength":0.35},
|
| 5 |
+
"cheap_cloud": {"tier":2,"cost_input":0.00015,"cost_output":0.0006,"latency":400,"strength":0.55},
|
| 6 |
+
"medium": {"tier":3,"cost_input":0.0015,"cost_output":0.006,"latency":800,"strength":0.80},
|
| 7 |
+
"frontier": {"tier":4,"cost_input":0.005,"cost_output":0.015,"latency":1500,"strength":0.93},
|
| 8 |
+
"specialist": {"tier":5,"cost_input":0.01,"cost_output":0.03,"latency":2000,"strength":0.97},
|
| 9 |
+
}
|
| 10 |
+
TIER_TO_MODEL = {1:"tiny_local",2:"cheap_cloud",3:"medium",4:"frontier",5:"specialist"}
|
| 11 |
+
TIER_COST_MULT = {1:0.05,2:0.15,3:0.75,4:1.0,5:1.5}
|
| 12 |
+
|
| 13 |
+
TASK_TEMPLATES = {
|
| 14 |
+
"quick_answer": [
|
| 15 |
+
"What is the capital of France?","Explain quantum computing briefly.",
|
| 16 |
+
"What is 237*452?","Define photosynthesis.","Who wrote Hamlet?",
|
| 17 |
+
"What is the speed of light?","List the primary colors.",
|
| 18 |
+
"What is GDP?","When was the Declaration of Independence signed?",
|
| 19 |
+
],
|
| 20 |
+
"coding": [
|
| 21 |
+
"Write a Python function to reverse a linked list.",
|
| 22 |
+
"Fix the bug in this React component.","Refactor auth module to JWT.",
|
| 23 |
+
"Implement LRU cache in Go.","Debug segfault in C++ thread pool.",
|
| 24 |
+
"Add unit tests for the payment module.","Optimize this SQL query.",
|
| 25 |
+
"Create a REST API for user management.","Implement binary search in Rust.",
|
| 26 |
+
"Write a recursive descent parser for JSON.",
|
| 27 |
+
],
|
| 28 |
+
"research": [
|
| 29 |
+
"Research latest transformer advances.",
|
| 30 |
+
"Find sources comparing LoRA and full FT.",
|
| 31 |
+
"Investigate data center climate impact.",
|
| 32 |
+
"What does literature say on speculative decoding?",
|
| 33 |
+
"Survey privacy-preserving ML techniques.",
|
| 34 |
+
"Compare reinforcement learning algorithms for robotics.",
|
| 35 |
+
"Analyze trends in LLM scaling laws.",
|
| 36 |
+
],
|
| 37 |
+
"document_drafting": [
|
| 38 |
+
"Draft project proposal for ML pipeline.",
|
| 39 |
+
"Write email to team about deployment.",
|
| 40 |
+
"Create technical report on performance.",
|
| 41 |
+
"Write a project brief for the migration.",
|
| 42 |
+
"Draft meeting notes summary.",
|
| 43 |
+
],
|
| 44 |
+
"legal_regulated": [
|
| 45 |
+
"Review this contract for liability clauses.",
|
| 46 |
+
"Check GDPR compliance for data pipeline.",
|
| 47 |
+
"Draft privacy policy section.",
|
| 48 |
+
"Analyze indemnification clause in vendor agreement.",
|
| 49 |
+
"Verify regulatory compliance for medical device software.",
|
| 50 |
+
],
|
| 51 |
+
"tool_heavy": [
|
| 52 |
+
"Search open issues and create summary.",
|
| 53 |
+
"Fetch API docs and generate client code.",
|
| 54 |
+
"Query Q3 sales and produce chart.",
|
| 55 |
+
"Aggregate metrics from 5 monitoring endpoints.",
|
| 56 |
+
],
|
| 57 |
+
"retrieval_heavy": [
|
| 58 |
+
"Answer based on 50-page document.",
|
| 59 |
+
"Find all payment processing mentions.",
|
| 60 |
+
"Retrieve relevant cases for legal query.",
|
| 61 |
+
"Summarize the quarterly earnings report.",
|
| 62 |
+
],
|
| 63 |
+
"long_horizon": [
|
| 64 |
+
"Plan 3-month roadmap.","Orchestrate multi-region deployment.",
|
| 65 |
+
"Redesign data architecture end-to-end.","Migrate monolith to microservices.",
|
| 66 |
+
],
|
| 67 |
+
"unknown_ambiguous": [
|
| 68 |
+
"Help me with this thing.","I need something about the server.",
|
| 69 |
+
"Can you look into that issue?","There's a problem with the data.",
|
| 70 |
+
],
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
def tier_success_prob(tier, difficulty):
|
| 74 |
+
strength = {1:0.35,2:0.55,3:0.80,4:0.93,5:0.97}.get(tier,0.5)
|
| 75 |
+
return strength ** (difficulty * 0.6)
|
| 76 |
+
|
| 77 |
+
def generate_training_trace(idx, rng):
|
| 78 |
+
task_types = list(TASK_TEMPLATES.keys())
|
| 79 |
+
task_type = rng.choice(task_types)
|
| 80 |
+
difficulty = {
|
| 81 |
+
"quick_answer":1,"document_drafting":2,"tool_heavy":2,"retrieval_heavy":2,
|
| 82 |
+
"research":3,"coding":3,"unknown_ambiguous":3,"long_horizon":4,"legal_regulated":5,
|
| 83 |
+
}[task_type]
|
| 84 |
+
|
| 85 |
+
# Try ALL tiers for this task to get ground truth
|
| 86 |
+
tier_outcomes = {}
|
| 87 |
+
for tier in range(1, 6):
|
| 88 |
+
sp = tier_success_prob(tier, difficulty)
|
| 89 |
+
tier_outcomes[tier] = rng.random() < sp
|
| 90 |
+
|
| 91 |
+
optimal_tier = 5 # default: need strongest
|
| 92 |
+
for tier in range(1, 6):
|
| 93 |
+
if tier_outcomes.get(tier, False):
|
| 94 |
+
optimal_tier = tier
|
| 95 |
+
break
|
| 96 |
+
|
| 97 |
+
actual_tier = rng.choice(list(range(1, 6)))
|
| 98 |
+
# Bias toward reasonable tiers
|
| 99 |
+
if difficulty <= 2:
|
| 100 |
+
actual_tier = rng.choices([1,2,3,4,5], weights=[3,4,2,1,0.5])[0]
|
| 101 |
+
elif difficulty == 3:
|
| 102 |
+
actual_tier = rng.choices([1,2,3,4,5], weights=[1,2,4,2,1])[0]
|
| 103 |
+
elif difficulty == 4:
|
| 104 |
+
actual_tier = rng.choices([1,2,3,4,5], weights=[0.5,1,2,4,2])[0]
|
| 105 |
+
else:
|
| 106 |
+
actual_tier = rng.choices([1,2,3,4,5], weights=[0.2,0.5,1,3,4])[0]
|
| 107 |
+
|
| 108 |
+
outcome = "success" if tier_outcomes.get(actual_tier, False) else "failure"
|
| 109 |
+
|
| 110 |
+
user_request = rng.choice(TASK_TEMPLATES[task_type])
|
| 111 |
+
cost_mult = TIER_COST_MULT[actual_tier]
|
| 112 |
+
base_tokens = rng.randint(800, 12000) + rng.randint(200, 6000)
|
| 113 |
+
cost = base_tokens / 1000 * MODEL_CONFIGS[TIER_TO_MODEL[actual_tier]]["cost_input"] * cost_mult
|
| 114 |
+
|
| 115 |
+
return {
|
| 116 |
+
"trace_id": f"train_{idx}",
|
| 117 |
+
"user_request": user_request,
|
| 118 |
+
"task_type": task_type,
|
| 119 |
+
"difficulty": difficulty,
|
| 120 |
+
"actual_tier": actual_tier,
|
| 121 |
+
"optimal_tier": optimal_tier,
|
| 122 |
+
"outcome": outcome,
|
| 123 |
+
"cost": cost,
|
| 124 |
+
"tier_outcomes": {str(k): v for k, v in tier_outcomes.items()},
|
| 125 |
+
"metadata": {"difficulty": difficulty, "optimal_tier": optimal_tier, "actual_tier": actual_tier},
|
| 126 |
+
}
|