File size: 5,522 Bytes
d4a573d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127

# ─── Synthetic Trace Generator (for training data) ────────────────
# Simulation parameters per model, keyed by model name. "tier" runs from 1
# (weakest, cheapest) to 5 (strongest, most expensive); "strength" holds the
# same per-tier values that tier_success_prob() uses.
# NOTE(review): cost_* look like per-1K-token prices and "latency" like
# milliseconds, but units are not stated anywhere — confirm.
MODEL_CONFIGS = {
    "tiny_local": {"tier": 1, "cost_input": 0.0001, "cost_output": 0.0002, "latency": 200, "strength": 0.35},
    "cheap_cloud": {"tier": 2, "cost_input": 0.00015, "cost_output": 0.0006, "latency": 400, "strength": 0.55},
    "medium": {"tier": 3, "cost_input": 0.0015, "cost_output": 0.006, "latency": 800, "strength": 0.80},
    "frontier": {"tier": 4, "cost_input": 0.005, "cost_output": 0.015, "latency": 1500, "strength": 0.93},
    "specialist": {"tier": 5, "cost_input": 0.01, "cost_output": 0.03, "latency": 2000, "strength": 0.97},
}

# Tier number -> model name. Derived from MODEL_CONFIGS so the two tables
# cannot disagree about which model occupies which tier.
TIER_TO_MODEL = {cfg["tier"]: name for name, cfg in MODEL_CONFIGS.items()}

# Extra per-tier multiplier applied on top of the model's price when the
# synthetic trace cost is computed.
TIER_COST_MULT = {1: 0.05, 2: 0.15, 3: 0.75, 4: 1.0, 5: 1.5}

# Example user requests for each task category. generate_training_trace()
# draws a category uniformly from these keys, looks up its fixed difficulty,
# and then picks one request from the category's list.
TASK_TEMPLATES = {
    "quick_answer": [
        "What is the capital of France?",
        "Explain quantum computing briefly.",
        "What is 237*452?",
        "Define photosynthesis.",
        "Who wrote Hamlet?",
        "What is the speed of light?",
        "List the primary colors.",
        "What is GDP?",
        "When was the Declaration of Independence signed?",
    ],
    "coding": [
        "Write a Python function to reverse a linked list.",
        "Fix the bug in this React component.",
        "Refactor auth module to JWT.",
        "Implement LRU cache in Go.",
        "Debug segfault in C++ thread pool.",
        "Add unit tests for the payment module.",
        "Optimize this SQL query.",
        "Create a REST API for user management.",
        "Implement binary search in Rust.",
        "Write a recursive descent parser for JSON.",
    ],
    "research": [
        "Research latest transformer advances.",
        "Find sources comparing LoRA and full FT.",
        "Investigate data center climate impact.",
        "What does literature say on speculative decoding?",
        "Survey privacy-preserving ML techniques.",
        "Compare reinforcement learning algorithms for robotics.",
        "Analyze trends in LLM scaling laws.",
    ],
    "document_drafting": [
        "Draft project proposal for ML pipeline.",
        "Write email to team about deployment.",
        "Create technical report on performance.",
        "Write a project brief for the migration.",
        "Draft meeting notes summary.",
    ],
    "legal_regulated": [
        "Review this contract for liability clauses.",
        "Check GDPR compliance for data pipeline.",
        "Draft privacy policy section.",
        "Analyze indemnification clause in vendor agreement.",
        "Verify regulatory compliance for medical device software.",
    ],
    "tool_heavy": [
        "Search open issues and create summary.",
        "Fetch API docs and generate client code.",
        "Query Q3 sales and produce chart.",
        "Aggregate metrics from 5 monitoring endpoints.",
    ],
    "retrieval_heavy": [
        "Answer based on 50-page document.",
        "Find all payment processing mentions.",
        "Retrieve relevant cases for legal query.",
        "Summarize the quarterly earnings report.",
    ],
    "long_horizon": [
        "Plan 3-month roadmap.",
        "Orchestrate multi-region deployment.",
        "Redesign data architecture end-to-end.",
        "Migrate monolith to microservices.",
    ],
    "unknown_ambiguous": [
        "Help me with this thing.",
        "I need something about the server.",
        "Can you look into that issue?",
        "There's a problem with the data.",
    ],
}

def tier_success_prob(tier, difficulty):
    """Return the simulated probability that ``tier`` succeeds on a task.

    The probability is ``strength ** (difficulty * 0.6)``. Here ``strength``
    is the ``"strength"`` entry of the model that TIER_TO_MODEL assigns to
    ``tier``, read from MODEL_CONFIGS, so there is a single source of truth
    for model capability. All configured strengths are below 1, so for a
    positive ``difficulty`` the result shrinks as the task gets harder.

    Args:
        tier: Tier number (1-5 in TIER_TO_MODEL).
        difficulty: Task difficulty level.

    Returns:
        A float probability. An unknown ``tier`` falls back to a strength
        of 0.5.
    """
    model_name = TIER_TO_MODEL.get(tier)
    strength = 0.5 if model_name is None else MODEL_CONFIGS[model_name]["strength"]
    return strength ** (difficulty * 0.6)

# Fixed difficulty (1 = easiest, 5 = hardest) assigned to each task type.
_TASK_DIFFICULTY = {
    "quick_answer": 1,
    "document_drafting": 2,
    "tool_heavy": 2,
    "retrieval_heavy": 2,
    "research": 3,
    "coding": 3,
    "unknown_ambiguous": 3,
    "long_horizon": 4,
    "legal_regulated": 5,
}


def _actual_tier_weights(difficulty):
    """Return sampling weights over tiers 1-5 for the tier actually used.

    The weights bias the simulated choice toward tiers that are reasonable
    for ``difficulty``. Some traces still use a tier that is too weak or
    stronger than needed.
    """
    if difficulty <= 2:
        return [3, 4, 2, 1, 0.5]
    if difficulty == 3:
        return [1, 2, 4, 2, 1]
    if difficulty == 4:
        return [0.5, 1, 2, 4, 2]
    return [0.2, 0.5, 1, 3, 4]


def generate_training_trace(idx, rng):
    """Generate one synthetic routing trace for training data.

    Steps:

    1. Draw a task type uniformly from TASK_TEMPLATES and map it to a fixed
       difficulty.
    2. Simulate every tier 1-5 once with tier_success_prob() to get
       ground-truth outcomes. The optimal tier is the lowest tier that
       succeeded, or 5 if none did.
    3. Sample the tier that was "actually" used with difficulty-dependent
       weights.

    Args:
        idx: Index used to build the trace id (``"train_<idx>"``).
        rng: A ``random.Random``-compatible generator. It must provide
            ``choice``, ``choices``, ``random`` and ``randint``.

    Returns:
        A dict with the keys trace_id, user_request, task_type, difficulty,
        actual_tier, optimal_tier, outcome ("success"/"failure"), cost,
        tier_outcomes (str tier -> bool) and metadata.

    Raises:
        KeyError: if a task type in TASK_TEMPLATES has no entry in
            ``_TASK_DIFFICULTY``.
    """
    tiers = [1, 2, 3, 4, 5]

    task_type = rng.choice(list(TASK_TEMPLATES.keys()))
    try:
        difficulty = _TASK_DIFFICULTY[task_type]
    except KeyError:
        raise KeyError(f"no difficulty defined for task type {task_type!r}") from None

    # Ground truth: simulate every tier independently on this task.
    tier_outcomes = {}
    for tier in tiers:
        tier_outcomes[tier] = rng.random() < tier_success_prob(tier, difficulty)

    # Lowest successful tier. If every tier failed, default to the strongest
    # tier (5), even though it also failed in this simulation.
    optimal_tier = next((tier for tier in tiers if tier_outcomes[tier]), 5)

    # Tier the simulated router "actually" used, biased toward reasonable tiers.
    actual_tier = rng.choices(tiers, weights=_actual_tier_weights(difficulty))[0]
    outcome = "success" if tier_outcomes[actual_tier] else "failure"

    user_request = rng.choice(TASK_TEMPLATES[task_type])

    # NOTE(review): the two token draws are summed and priced at the model's
    # cost_input only (cost_output is never used), then scaled again by
    # TIER_COST_MULT — confirm this pricing is intended.
    base_tokens = rng.randint(800, 12000) + rng.randint(200, 6000)
    model_cfg = MODEL_CONFIGS[TIER_TO_MODEL[actual_tier]]
    cost = base_tokens / 1000 * model_cfg["cost_input"] * TIER_COST_MULT[actual_tier]

    return {
        "trace_id": f"train_{idx}",
        "user_request": user_request,
        "task_type": task_type,
        "difficulty": difficulty,
        "actual_tier": actual_tier,
        "optimal_tier": optimal_tier,
        "outcome": outcome,
        "cost": cost,
        "tier_outcomes": {str(k): v for k, v in tier_outcomes.items()},
        "metadata": {"difficulty": difficulty, "optimal_tier": optimal_tier, "actual_tier": actual_tier},
    }