#!/usr/bin/env python3 """ Generate SFT data for Budget Router. Default path is deliberately zero-API-cost: distill the existing PPO hard_multi policy into chat transcripts, then push the dataset to the Hub for HF Jobs. Optional LLM labeling is available with --teacher llm, but it costs one large model call per environment step (20 calls per episode). """ from __future__ import annotations import argparse import json import math import os from pathlib import Path from typing import Any, Callable import numpy as np from budget_router.environment import BudgetRouterEnv from budget_router.models import Action, ActionType, Observation, TaskConfig from budget_router.policies import heuristic_baseline_policy from budget_router.reward import episode_metrics, grade_episode from budget_router.tasks import HARD_MULTI, TASK_PRESETS from inference import LLMRouter, SYSTEM_PROMPT VALID_ACTIONS = ["route_to_a", "route_to_b", "route_to_c", "shed_load"] PPO_ACTION_NAMES = ["route_to_a", "route_to_b", "route_to_c", "shed_load"] DEFAULT_DATASET_REPO = "akshay4/budget-router-sft-data" DEFAULT_PPO_MODEL_PATH = "trained_models/ppo_hard_multi_100k.zip" _PPO_POLICY_CACHE: dict[str, Callable[[Observation], str]] = {} def _obs_to_array(obs: Observation) -> np.ndarray: return np.array( [ obs.provider_a_status, obs.provider_b_status, obs.provider_c_status, obs.budget_remaining, obs.queue_backlog, obs.system_latency, obs.step_count, ], dtype=np.float32, ) def _steps_remaining(obs: Observation, max_steps: int = 20) -> int: elapsed = int(round(float(obs.step_count) * max_steps)) return max(0, max_steps - elapsed) def _trend_text(obs: Observation, previous_obs: Observation | None, previous2_obs: Observation | None) -> str: if previous2_obs is not None: ta = (obs.provider_a_status - previous2_obs.provider_a_status) / 2.0 tb = (obs.provider_b_status - previous2_obs.provider_b_status) / 2.0 tc = (obs.provider_c_status - previous2_obs.provider_c_status) / 2.0 return f"trend (avg/step, 2-step): A:{ta:+.3f} B:{tb:+.3f} C:{tc:+.3f}" if previous_obs is not None: ta = obs.provider_a_status - previous_obs.provider_a_status tb = obs.provider_b_status - previous_obs.provider_b_status tc = obs.provider_c_status - previous_obs.provider_c_status return f"trend (1-step only, noisy): A:{ta:+.3f} B:{tb:+.3f} C:{tc:+.3f}" return "trend: unavailable" def _budget_runway_text(obs: Observation, previous_obs: Observation | None) -> str: if previous_obs is None: return "budget_runway_at_current_rate: >20 steps" budget_spent = float(previous_obs.budget_remaining) - float(obs.budget_remaining) if budget_spent <= 0.001: return "budget_runway_at_current_rate: >20 steps" runway = int(float(obs.budget_remaining) / budget_spent) return f"budget_runway_at_current_rate: ~{runway} steps" def _previous_step_feedback(obs: Observation) -> str: metadata = getattr(obs, "metadata", None) or {} if not metadata.get("action_type"): return "" parts = [ "previous_step_feedback:", f" previous_action: {metadata.get('action_type')}", ] if obs.reward is not None: parts.append(f" previous_reward: {float(obs.reward):+.2f}") if metadata.get("request_succeeded") is not None: parts.append(f" previous_success: {str(bool(metadata.get('request_succeeded'))).lower()}") if metadata.get("cost") is not None: parts.append(f" previous_cost: {float(metadata.get('cost')):.2f}") if metadata.get("latency_ms") is not None: parts.append(f" previous_latency_ms: {float(metadata.get('latency_ms')):.2f}") if metadata.get("budget_exhausted"): parts.append(" previous_budget_exhausted: true") return "\n".join(parts) def format_observation_for_sft( *, obs: Observation, task_name: str, previous_obs: Observation | None, previous2_obs: Observation | None, ) -> str: """Public observation text used consistently for SFT train/eval.""" lines = [ f"task: {task_name}", f"provider_a_status: {obs.provider_a_status:.3f}", f"provider_b_status: {obs.provider_b_status:.3f}", f"provider_c_status: {obs.provider_c_status:.3f}", f"budget_remaining: {obs.budget_remaining:.3f}", f"queue_backlog: {obs.queue_backlog:.3f}", f"system_latency: {obs.system_latency:.3f}", f"step_count: {obs.step_count:.3f}", f"steps_remaining: {_steps_remaining(obs)}", _trend_text(obs, previous_obs, previous2_obs), _budget_runway_text(obs, previous_obs), ] feedback = _previous_step_feedback(obs) if feedback: lines.append(feedback) return "\n".join(lines) def run_heuristic_episode(task_cfg: TaskConfig, seed: int) -> dict[str, Any]: env = BudgetRouterEnv() obs = env.reset(seed=seed, scenario=task_cfg) total_reward = 0.0 while not obs.done: obs = env.step(heuristic_baseline_policy(obs)) total_reward += float(obs.reward or 0.0) grader = grade_episode(env._internal.history) return { "grader_score": float(grader["overall_score"]), "total_reward": total_reward, "grader": grader, } def _load_ppo_policy(model_path: str) -> Callable[[Observation], str]: if model_path in _PPO_POLICY_CACHE: return _PPO_POLICY_CACHE[model_path] try: from stable_baselines3 import PPO except ImportError as exc: raise RuntimeError( "PPO teacher requires training dependencies. Run `uv sync --extra training` " "or use --teacher heuristic/llm." ) from exc path = Path(model_path) if not path.exists(): raise FileNotFoundError(f"PPO model not found: {path}") model = PPO.load(str(path)) def choose(obs: Observation) -> str: action_idx, _ = model.predict(_obs_to_array(obs), deterministic=True) idx = int(action_idx) return PPO_ACTION_NAMES[idx] if 0 <= idx < len(PPO_ACTION_NAMES) else "shed_load" _PPO_POLICY_CACHE[model_path] = choose return choose def _load_llm_policy(task_name: str) -> Callable[[Observation], str]: api_key = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY") if not api_key: raise RuntimeError("LLM teacher requires HF_TOKEN or API_KEY in the environment.") router = LLMRouter( api_base_url=os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1"), model_name=os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct"), api_key=api_key, ) router.reset(task_name=task_name) def choose(obs: Observation) -> str: return router.choose_action(obs).action_type.value return choose def collect_teacher_episode( *, task_name: str, task_cfg: TaskConfig, seed: int, teacher: str, ppo_model_path: str, ) -> dict[str, Any]: if teacher == "ppo": choose_action = _load_ppo_policy(ppo_model_path) elif teacher == "heuristic": choose_action = lambda obs: heuristic_baseline_policy(obs).action_type.value elif teacher == "llm": choose_action = _load_llm_policy(task_name) else: raise ValueError(f"Unknown teacher {teacher!r}") env = BudgetRouterEnv() obs = env.reset(seed=seed, scenario=task_cfg) messages = [{"role": "system", "content": SYSTEM_PROMPT}] previous2_obs: Observation | None = None previous_obs: Observation | None = None actions: list[str] = [] total_reward = 0.0 while not obs.done: obs_text = format_observation_for_sft( obs=obs, task_name=task_name, previous_obs=previous_obs, previous2_obs=previous2_obs, ) action_str = choose_action(obs) if action_str not in VALID_ACTIONS: action_str = "shed_load" messages.append({"role": "user", "content": obs_text}) messages.append({"role": "assistant", "content": action_str}) actions.append(action_str) previous2_obs = previous_obs previous_obs = obs obs = env.step(Action(action_type=ActionType(action_str))) total_reward += float(obs.reward or 0.0) grader = grade_episode(env._internal.history) return { "seed": seed, "teacher": teacher, "messages": messages, "actions": actions, "grader_score": float(grader["overall_score"]), "total_reward": total_reward, "grader": grader, "metrics": episode_metrics(env._internal.history), } def select_training_rows( episodes: list[dict[str, Any]], *, top_fraction: float, min_keep: int, min_delta: float, ) -> list[dict[str, Any]]: ranked = sorted(episodes, key=lambda item: float(item["delta_vs_heuristic"]), reverse=True) target = max(min_keep, int(math.ceil(len(ranked) * top_fraction))) positive = [ep for ep in ranked if float(ep["delta_vs_heuristic"]) >= min_delta] source = positive if len(positive) >= min_keep else ranked return source[: min(target, len(source))] def write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None: path.parent.mkdir(parents=True, exist_ok=True) with path.open("w", encoding="utf-8") as f: for row in rows: f.write(json.dumps(row, sort_keys=True) + "\n") def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Generate Budget Router SFT dataset.") parser.add_argument("--teacher", choices=["ppo", "heuristic", "llm"], default=os.getenv("TEACHER_POLICY", "ppo")) parser.add_argument("--task", default=os.getenv("TASK_NAME", "hard_multi"), choices=sorted(TASK_PRESETS)) parser.add_argument("--start-seed", type=int, default=int(os.getenv("SFT_START_SEED", "1000"))) parser.add_argument("--n-episodes", type=int, default=int(os.getenv("SFT_N_EPISODES", "100"))) parser.add_argument("--top-fraction", type=float, default=float(os.getenv("SFT_TOP_FRACTION", "0.30"))) parser.add_argument("--min-keep", type=int, default=int(os.getenv("SFT_MIN_KEEP", "20"))) parser.add_argument("--min-delta", type=float, default=float(os.getenv("SFT_MIN_DELTA", "0.0"))) parser.add_argument("--ppo-model-path", default=os.getenv("PPO_MODEL_PATH", DEFAULT_PPO_MODEL_PATH)) parser.add_argument("--dataset-repo", default=os.getenv("DATASET_REPO", DEFAULT_DATASET_REPO)) parser.add_argument("--local-jsonl", default=os.getenv("SFT_LOCAL_JSONL", "outputs/sft_dataset.jsonl")) parser.add_argument("--no-push", action="store_true", help="Write local JSONL only; do not push to Hub.") return parser.parse_args() def main() -> None: args = parse_args() task_cfg = TASK_PRESETS[args.task] seeds = list(range(args.start_seed, args.start_seed + args.n_episodes)) if args.teacher == "llm": print( f"[sft-data] teacher=llm n_episodes={args.n_episodes}; " f"expected large-model calls <= {args.n_episodes * task_cfg.max_steps}", flush=True, ) else: print(f"[sft-data] teacher={args.teacher} uses 0 large-LLM calls", flush=True) episodes: list[dict[str, Any]] = [] for i, seed in enumerate(seeds, start=1): teacher_ep = collect_teacher_episode( task_name=args.task, task_cfg=task_cfg, seed=seed, teacher=args.teacher, ppo_model_path=args.ppo_model_path, ) heuristic_ep = run_heuristic_episode(task_cfg, seed) delta = teacher_ep["grader_score"] - heuristic_ep["grader_score"] teacher_ep["heuristic_score"] = heuristic_ep["grader_score"] teacher_ep["delta_vs_heuristic"] = delta episodes.append(teacher_ep) print( f"[sft-data] {i:03d}/{len(seeds)} seed={seed} " f"teacher={teacher_ep['grader_score']:.4f} heuristic={heuristic_ep['grader_score']:.4f} " f"delta={delta:+.4f}", flush=True, ) kept = select_training_rows( episodes, top_fraction=args.top_fraction, min_keep=args.min_keep, min_delta=args.min_delta, ) dataset_rows = [ { "messages": ep["messages"], "seed": ep["seed"], "teacher": ep["teacher"], "teacher_score": ep["grader_score"], "heuristic_score": ep["heuristic_score"], "delta_vs_heuristic": ep["delta_vs_heuristic"], "actions": ep["actions"], } for ep in kept ] write_jsonl(Path(args.local_jsonl), dataset_rows) mean_all = sum(float(ep["grader_score"]) for ep in episodes) / len(episodes) mean_kept = sum(float(ep["grader_score"]) for ep in kept) / len(kept) mean_delta = sum(float(ep["delta_vs_heuristic"]) for ep in kept) / len(kept) print( "[sft-data] summary " f"generated={len(episodes)} kept={len(kept)} mean_all={mean_all:.4f} " f"mean_kept={mean_kept:.4f} mean_delta_kept={mean_delta:+.4f} " f"local_jsonl={args.local_jsonl}", flush=True, ) if not args.no_push: token = os.environ.get("HF_TOKEN") if not token: raise RuntimeError("HF_TOKEN must be set to push the dataset. Use --no-push for local only.") from datasets import Dataset Dataset.from_list(dataset_rows).push_to_hub(args.dataset_repo, token=token) print(f"[sft-data] pushed dataset to https://huggingface.co/datasets/{args.dataset_repo}", flush=True) if __name__ == "__main__": main()