#!/usr/bin/env python3
"""Train a learned model router for Agent Cost Optimizer.

Architecture: CARROT-style plug-in router.
- Per-tier P(success|query) classifiers (XGBoost)
- Per-tier cost estimators
- Route to cheapest tier where P(success) > threshold

Ground truth: optimal_tier = min{tier : success(tier)} from execution traces.
"""

import json
import os
import pickle
import random
import sys
import uuid
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

# ─── Feature Extraction ───────────────────────────────────────────

TASK_TYPES = [
    "quick_answer", "coding", "research", "document_drafting",
    "legal_regulated", "tool_heavy", "retrieval_heavy",
    "long_horizon", "unknown_ambiguous",
]
TASK_TYPE_TO_IDX = {t: i for i, t in enumerate(TASK_TYPES)}

# Index used for task types that are not listed in TASK_TYPES. Derived from
# the table rather than hard-coded so it stays correct if TASK_TYPES changes.
_UNKNOWN_TASK_IDX = TASK_TYPE_TO_IDX["unknown_ambiguous"]

CODE_KW = ["python", "javascript", "code", "function", "bug", "debug", "refactor",
           "implement", "test", "compile", "runtime", "class", "module", "package",
           "async", "thread", "queue", "stack", "heap", "pointer", "segfault", "linter"]
LEGAL_KW = ["contract", "legal", "compliance", "gdpr", "privacy", "policy",
            "regulatory", "liability", "clause", "indemnification", "tos"]
RESEARCH_KW = ["research", "find sources", "literature", "investigate", "compare",
               "analyze", "study", "survey", "paper", "arxiv", "citation"]
TOOL_KW = ["search", "fetch", "retrieve", "query", "api", "database", "scrape",
           "lookup", "download", "upload", "index", "aggregate"]
LONG_KW = ["plan", "project", "roadmap", "orchestrate", "multi-step", "migrate",
           "pipeline", "end-to-end", "architecture", "workflow", "deploy"]
MATH_KW = ["calculate", "compute", "solve", "equation", "formula", "optimize",
           "probability", "integral", "derivative", "matrix"]


def _count_keywords(text: str, keywords: List[str]) -> int:
    """Return how many entries of *keywords* occur as substrings of *text*."""
    return sum(1 for kw in keywords if kw in text)


def extract_features(request: str, task_type: str,
                     metadata: Optional[Dict] = None) -> Dict[str, Any]:
    """Build the routing feature dict for a single request.

    Args:
        request: Raw request text.
        task_type: Task label; labels not in ``TASK_TYPES`` get the index of
            ``"unknown_ambiguous"`` and an all-zero one-hot encoding.
        metadata: Optional extra information. When it is non-empty, a
            ``"difficulty"`` feature is added, defaulting to 3 if the key is
            absent. When it is ``None`` or empty, no ``"difficulty"`` key is
            produced at all, so the set of keys depends on this argument.

    Returns:
        A dict mapping feature name to value: length statistics, keyword
        flags/counts, the task-type index, one ``tt_<type>`` one-hot entry
        per known task type, and optionally ``"difficulty"``.
    """
    lowered = request.lower()

    # Keyword matching is plain substring containment on the lowercased text.
    # NOTE(review): short keywords also match inside longer words (e.g. "tos"
    # in "photos", "api" in "rapid"); kept as-is because changing it would
    # alter the feature values.
    code_hits = _count_keywords(lowered, CODE_KW)
    legal_hits = _count_keywords(lowered, LEGAL_KW)
    research_hits = _count_keywords(lowered, RESEARCH_KW)
    tool_hits = _count_keywords(lowered, TOOL_KW)

    feats = {
        "request_length": len(request),
        "num_words": len(request.split()),
        # Counts terminator characters, not parsed sentences.
        "num_sentences": (request.count(".") + request.count("!")
                          + request.count("?")),
        "has_code_kw": int(code_hits > 0),
        "num_code_kw": code_hits,
        "has_legal_kw": int(legal_hits > 0),
        "num_legal_kw": legal_hits,
        "has_research_kw": int(research_hits > 0),
        "num_research_kw": research_hits,
        "has_tool_kw": int(tool_hits > 0),
        "num_tool_kw": tool_hits,
        "has_long_kw": int(any(kw in lowered for kw in LONG_KW)),
        "has_math_kw": int(any(kw in lowered for kw in MATH_KW)),
        "task_type_idx": TASK_TYPE_TO_IDX.get(task_type, _UNKNOWN_TASK_IDX),
    }

    # One-hot task type
    for tt in TASK_TYPES:
        feats[f"tt_{tt}"] = int(task_type == tt)

    if metadata:
        feats["difficulty"] = metadata.get("difficulty", 3)
    return feats


def feats_to_array(feats: Dict) -> List[float]:
    """Convert a feature dict to a list of floats ordered by sorted key name.

    The ordering is derived from the keys of *feats* itself, so two dicts with
    different key sets produce arrays of different length and alignment.
    """
    return [float(feats[key]) for key in sorted(feats)]


FEAT_KEYS = None  # set on first call


def feats_to_array_safe(feats: Dict) -> List[float]:
    """Convert a feature dict to a float list using a process-wide key order.

    The sorted keys of the first dict ever passed in are stored in the
    module-level ``FEAT_KEYS`` and reused for every later call. Keys missing
    from a later dict are filled with 0.0; keys that were not present on the
    first call are ignored.
    """
    global FEAT_KEYS
    if FEAT_KEYS is None:
        FEAT_KEYS = sorted(feats.keys())
    return [float(feats.get(key, 0.0)) for key in FEAT_KEYS]