|
|
| |
# Per-model settings, keyed by model name. Each entry records the model's
# tier, its input/output cost rates, a latency figure and a "strength" score.
# NOTE(review): units are not stated here (costs look like a per-1k-token
# price, latency like milliseconds) - confirm against the consumers.
MODEL_CONFIGS = {
    "tiny_local": {
        "tier": 1,
        "cost_input": 0.0001,
        "cost_output": 0.0002,
        "latency": 200,
        "strength": 0.35,
    },
    "cheap_cloud": {
        "tier": 2,
        "cost_input": 0.00015,
        "cost_output": 0.0006,
        "latency": 400,
        "strength": 0.55,
    },
    "medium": {
        "tier": 3,
        "cost_input": 0.0015,
        "cost_output": 0.006,
        "latency": 800,
        "strength": 0.80,
    },
    "frontier": {
        "tier": 4,
        "cost_input": 0.005,
        "cost_output": 0.015,
        "latency": 1500,
        "strength": 0.93,
    },
    "specialist": {
        "tier": 5,
        "cost_input": 0.01,
        "cost_output": 0.03,
        "latency": 2000,
        "strength": 0.97,
    },
}

# Reverse lookup from tier number to model name, read off the "tier" field of
# each MODEL_CONFIGS entry so the two tables cannot drift apart.
TIER_TO_MODEL = {
    config["tier"]: model_name for model_name, config in MODEL_CONFIGS.items()
}

# Extra per-tier multiplier applied on top of the model's own cost rate.
TIER_COST_MULT = {
    1: 0.05,
    2: 0.15,
    3: 0.75,
    4: 1.0,
    5: 1.5,
}
|
|
# Example user requests grouped by task type. The keys are the task-type
# labels; each value is the pool of prompt strings for that type.
TASK_TEMPLATES = {
    "quick_answer": [
        "What is the capital of France?",
        "Explain quantum computing briefly.",
        "What is 237*452?",
        "Define photosynthesis.",
        "Who wrote Hamlet?",
        "What is the speed of light?",
        "List the primary colors.",
        "What is GDP?",
        "When was the Declaration of Independence signed?",
    ],
    "coding": [
        "Write a Python function to reverse a linked list.",
        "Fix the bug in this React component.",
        "Refactor auth module to JWT.",
        "Implement LRU cache in Go.",
        "Debug segfault in C++ thread pool.",
        "Add unit tests for the payment module.",
        "Optimize this SQL query.",
        "Create a REST API for user management.",
        "Implement binary search in Rust.",
        "Write a recursive descent parser for JSON.",
    ],
    "research": [
        "Research latest transformer advances.",
        "Find sources comparing LoRA and full FT.",
        "Investigate data center climate impact.",
        "What does literature say on speculative decoding?",
        "Survey privacy-preserving ML techniques.",
        "Compare reinforcement learning algorithms for robotics.",
        "Analyze trends in LLM scaling laws.",
    ],
    "document_drafting": [
        "Draft project proposal for ML pipeline.",
        "Write email to team about deployment.",
        "Create technical report on performance.",
        "Write a project brief for the migration.",
        "Draft meeting notes summary.",
    ],
    "legal_regulated": [
        "Review this contract for liability clauses.",
        "Check GDPR compliance for data pipeline.",
        "Draft privacy policy section.",
        "Analyze indemnification clause in vendor agreement.",
        "Verify regulatory compliance for medical device software.",
    ],
    "tool_heavy": [
        "Search open issues and create summary.",
        "Fetch API docs and generate client code.",
        "Query Q3 sales and produce chart.",
        "Aggregate metrics from 5 monitoring endpoints.",
    ],
    "retrieval_heavy": [
        "Answer based on 50-page document.",
        "Find all payment processing mentions.",
        "Retrieve relevant cases for legal query.",
        "Summarize the quarterly earnings report.",
    ],
    "long_horizon": [
        "Plan 3-month roadmap.",
        "Orchestrate multi-region deployment.",
        "Redesign data architecture end-to-end.",
        "Migrate monolith to microservices.",
    ],
    "unknown_ambiguous": [
        "Help me with this thing.",
        "I need something about the server.",
        "Can you look into that issue?",
        "There's a problem with the data.",
    ],
}
|
|
def tier_success_prob(tier, difficulty):
    """Return the modelled success probability of ``tier`` at ``difficulty``.

    The tier's base strength is raised to the power ``difficulty * 0.6``, so
    for strengths below 1 a larger difficulty gives a smaller result.

    Args:
        tier: Tier number. Tiers 1-5 have a fixed strength; any other value
            falls back to a strength of 0.5.
        difficulty: Numeric difficulty used to scale the exponent.

    Returns:
        ``strength ** (difficulty * 0.6)`` as a float.
    """
    strength_by_tier = {1: 0.35, 2: 0.55, 3: 0.80, 4: 0.93, 5: 0.97}
    if tier in strength_by_tier:
        base = strength_by_tier[tier]
    else:
        # Tiers outside the table get a neutral default strength.
        base = 0.5
    exponent = difficulty * 0.6
    return base ** exponent
|
|
def generate_training_trace(idx, rng):
    """Build one synthetic routing trace as a plain dict.

    A task type is picked at random and mapped to a fixed difficulty. A
    success/failure outcome is then sampled for every tier 1-5, a tier is
    chosen with difficulty-dependent weights, and a cost is computed for it.

    Args:
        idx: Identifier embedded in the trace id as ``train_{idx}``.
        rng: Random source providing ``choice``, ``random``, ``choices`` and
            ``randint`` (presumably a ``random.Random`` - confirm with the
            caller). The number and order of draws is fixed, so a seeded
            ``rng`` reproduces the same trace.

    Returns:
        A dict with keys ``trace_id``, ``user_request``, ``task_type``,
        ``difficulty``, ``actual_tier``, ``optimal_tier``, ``outcome``,
        ``cost``, ``tier_outcomes`` (string tier -> bool) and ``metadata``.

    Raises:
        KeyError: If the chosen task type has no entry in the difficulty
            table below.
    """
    difficulty_by_task = {
        "quick_answer": 1,
        "document_drafting": 2,
        "tool_heavy": 2,
        "retrieval_heavy": 2,
        "research": 3,
        "coding": 3,
        "unknown_ambiguous": 3,
        "long_horizon": 4,
        "legal_regulated": 5,
    }
    tiers = [1, 2, 3, 4, 5]

    task_type = rng.choice(list(TASK_TEMPLATES))
    difficulty = difficulty_by_task[task_type]

    # One independent success draw per tier. The probability is computed
    # before the draw on every iteration, and all five tiers are sampled.
    tier_outcomes = {
        tier: tier_success_prob(tier, difficulty) > rng.random()
        for tier in tiers
    }

    # Lowest tier that succeeded; stays at 5 when no tier succeeded at all.
    optimal_tier = next((tier for tier in tiers if tier_outcomes[tier]), 5)

    # This uniform pick is never used (the weighted pick below replaces it),
    # but it consumes randomness, so it is kept to leave the random stream,
    # and therefore seeded output, unchanged.
    rng.choice(tiers)

    # Selection weights over tiers 1-5: easier tasks favour low tiers, harder
    # tasks favour high tiers. Difficulties 1 and 2 share one row; anything
    # not listed (i.e. difficulty 5) uses the fallback row.
    easy_weights = [3, 4, 2, 1, 0.5]
    weights_by_difficulty = {
        1: easy_weights,
        2: easy_weights,
        3: [1, 2, 4, 2, 1],
        4: [0.5, 1, 2, 4, 2],
    }
    tier_weights = weights_by_difficulty.get(difficulty, [0.2, 0.5, 1, 3, 4])
    (actual_tier,) = rng.choices(tiers, weights=tier_weights)

    if tier_outcomes[actual_tier]:
        outcome = "success"
    else:
        outcome = "failure"

    user_request = rng.choice(TASK_TEMPLATES[task_type])

    # Token count is the sum of two independent draws, in this order.
    first_draw = rng.randint(800, 12000)
    second_draw = rng.randint(200, 6000)
    base_tokens = first_draw + second_draw

    # NOTE(review): the whole token count is priced at the model's
    # "cost_input" rate and then scaled by TIER_COST_MULT; "cost_output" is
    # not read here - confirm this is intended.
    input_rate = MODEL_CONFIGS[TIER_TO_MODEL[actual_tier]]["cost_input"]
    cost = base_tokens / 1000 * input_rate * TIER_COST_MULT[actual_tier]

    metadata = {
        "difficulty": difficulty,
        "optimal_tier": optimal_tier,
        "actual_tier": actual_tier,
    }
    return {
        "trace_id": f"train_{idx}",
        "user_request": user_request,
        "task_type": task_type,
        "difficulty": difficulty,
        "actual_tier": actual_tier,
        "optimal_tier": optimal_tier,
        "outcome": outcome,
        "cost": cost,
        "tier_outcomes": {str(tier): ok for tier, ok in tier_outcomes.items()},
        "metadata": metadata,
    }
|
|