narcolepticchicken commited on
Commit
8275b23
Β·
verified Β·
1 Parent(s): c122389

Upload train_router.py

Browse files
Files changed (1) hide show
  1. train_router.py +77 -0
train_router.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Train a learned model router for Agent Cost Optimizer.
3
+
4
+ Architecture: CARROT-style plug-in router.
5
+ - Per-tier P(success|query) classifiers (XGBoost)
6
+ - Per-tier cost estimators
7
+ - Route to cheapest tier where P(success) > threshold
8
+
9
+ Ground truth: optimal_tier = min{tier : success(tier)} from execution traces.
10
+ """
11
+ import json, os, sys, random, pickle, uuid
12
+ import numpy as np
13
+ from datetime import datetime, timedelta
14
+ from collections import defaultdict
15
+ from typing import Dict, List, Tuple, Any, Optional
16
+ from dataclasses import dataclass
17
+ from enum import Enum
18
+
19
+ # ─── Feature Extraction ───────────────────────────────────────────
20
+ TASK_TYPES = [
21
+ "quick_answer", "coding", "research", "document_drafting",
22
+ "legal_regulated", "tool_heavy", "retrieval_heavy",
23
+ "long_horizon", "unknown_ambiguous",
24
+ ]
25
+ TASK_TYPE_TO_IDX = {t: i for i, t in enumerate(TASK_TYPES)}
26
+
27
+ CODE_KW = ["python","javascript","code","function","bug","debug","refactor",
28
+ "implement","test","compile","runtime","class","module","package",
29
+ "async","thread","queue","stack","heap","pointer","segfault","linter"]
30
+ LEGAL_KW = ["contract","legal","compliance","gdpr","privacy","policy",
31
+ "regulatory","liability","clause","indemnification","tos"]
32
+ RESEARCH_KW = ["research","find sources","literature","investigate","compare",
33
+ "analyze","study","survey","paper","arxiv","citation"]
34
+ TOOL_KW = ["search","fetch","retrieve","query","api","database","scrape",
35
+ "lookup","download","upload","index","aggregate"]
36
+ LONG_KW = ["plan","project","roadmap","orchestrate","multi-step","migrate",
37
+ "pipeline","end-to-end","architecture","workflow","deploy"]
38
+ MATH_KW = ["calculate","compute","solve","equation","formula","optimize",
39
+ "probability","integral","derivative","matrix"]
40
+
41
+ def extract_features(request: str, task_type: str, metadata: Dict = None) -> Dict[str, Any]:
42
+ r = request.lower()
43
+ feats = {
44
+ "request_length": len(request),
45
+ "num_words": len(request.split()),
46
+ "num_sentences": request.count(".") + request.count("!") + request.count("?"),
47
+ "has_code_kw": int(any(kw in r for kw in CODE_KW)),
48
+ "num_code_kw": sum(1 for kw in CODE_KW if kw in r),
49
+ "has_legal_kw": int(any(kw in r for kw in LEGAL_KW)),
50
+ "num_legal_kw": sum(1 for kw in LEGAL_KW if kw in r),
51
+ "has_research_kw": int(any(kw in r for kw in RESEARCH_KW)),
52
+ "num_research_kw": sum(1 for kw in RESEARCH_KW if kw in r),
53
+ "has_tool_kw": int(any(kw in r for kw in TOOL_KW)),
54
+ "num_tool_kw": sum(1 for kw in TOOL_KW if kw in r),
55
+ "has_long_kw": int(any(kw in r for kw in LONG_KW)),
56
+ "has_math_kw": int(any(kw in r for kw in MATH_KW)),
57
+ "task_type_idx": TASK_TYPE_TO_IDX.get(task_type, 8),
58
+ }
59
+ # One-hot task type
60
+ for tt in TASK_TYPES:
61
+ feats[f"tt_{tt}"] = int(task_type == tt)
62
+ if metadata:
63
+ feats["difficulty"] = metadata.get("difficulty", 3)
64
+ return feats
65
+
66
+ def feats_to_array(feats: Dict) -> List[float]:
67
+ """Convert feature dict to fixed-order array."""
68
+ keys = sorted(feats.keys())
69
+ return [float(feats[k]) for k in keys]
70
+
71
+ FEAT_KEYS = None # set on first call
72
+
73
+ def feats_to_array_safe(feats: Dict) -> List[float]:
74
+ global FEAT_KEYS
75
+ if FEAT_KEYS is None:
76
+ FEAT_KEYS = sorted(feats.keys())
77
+ return [float(feats.get(k, 0.0)) for k in FEAT_KEYS]