| from __future__ import annotations |
|
|
| from llmserve_env.models import EpisodeLog |
| from server.baseline_agent import HeuristicPolicy |
| from server.llmserve_environment import LLMServeEnvironment |
| from server.optimal_solver import OptimalSolver |
|
|
|
|
class GraderEngine:
    """Score an episode log against oracle references and baseline policies.

    A raw score in ``[0, 1]`` is computed from four weighted components
    (throughput, SLO compliance, GPU memory, cost).  ``grade`` then rescales
    that raw score around an anchor, which is the better of the heuristic
    baseline and the PPO baseline for the same task.
    """

    # Baseline caches live on the class, so every GraderEngine instance in the
    # process shares them and each task's baseline episode is run only once.
    _shared_ppo_baselines: dict[str, float] = {}
    _shared_heuristic_baselines: dict[str, float] = {}

    # Seeds used for the baseline episodes.
    PPO_BASELINE_SEED = 42
    HEURISTIC_BASELINE_SEED = 142
    # Episode length used when the task config carries no "max_steps" entry.
    DEFAULT_MAX_STEPS = 60
    # Mean GPU memory (GB) at or below this value is not penalized; the memory
    # component falls linearly to 0 at twice this value.
    MEMORY_BUDGET_GB = 38.0
    # Floor applied to denominators to avoid division by zero.
    _EPSILON = 1e-6

    def __init__(self) -> None:
        """Create the oracle solver and bind the class-level baseline caches."""
        self.optimal_solver = OptimalSolver()
        self._ppo_baselines = self._shared_ppo_baselines
        self._heuristic_baselines = self._shared_heuristic_baselines

    @staticmethod
    def _empty_breakdown() -> dict[str, float]:
        """Return the all-zero component breakdown used for empty episodes."""
        return {"throughput": 0.0, "slo": 0.0, "memory": 0.0, "cost": 0.0}

    def _run_policy_episode(self, task_id: str, seed: int, policy) -> float:
        """Run ``policy`` for one simulated episode and return its raw score.

        Args:
            task_id: Task to reset the environment with.
            seed: Seed passed to both the environment constructor and reset.
            policy: Object exposing ``act(observation, task_id)``; its
                ``reset()`` method is called first when it has one.

        Returns:
            The raw score of the exported episode log.
        """
        env = LLMServeEnvironment(seed=seed, mode="sim")
        if hasattr(policy, "reset"):
            policy.reset()
        observation = env.reset(seed=seed, task_id=task_id)
        task_cfg = env.task_config or {}
        max_steps = int(task_cfg.get("max_steps", self.DEFAULT_MAX_STEPS))
        for _ in range(max_steps):
            observation = env.step(policy.act(observation, task_id))
            # Observations without a ``done`` attribute count as "not done".
            if getattr(observation, "done", False):
                break
        raw_score, _ = self._compute_raw_score(env.export_episode_log())
        return raw_score

    def get_ppo_baseline(self, task_id: str) -> float:
        """Return the cached PPO baseline raw score for ``task_id``.

        The PPO agent is imported lazily.  When no weights are found, or when
        anything in the PPO path raises, the heuristic baseline is used
        instead.  Whichever value is returned is cached, so a failed PPO run
        is not retried for that task.
        """
        if task_id in self._ppo_baselines:
            return self._ppo_baselines[task_id]

        baseline: float | None = None
        try:
            from agents.ppo_agent import PPOAgent, find_weights

            weights_path = find_weights(task_id)
            if weights_path:
                baseline = self._run_policy_episode(
                    task_id, self.PPO_BASELINE_SEED, PPOAgent(weights_path)
                )
        except Exception:
            # Best-effort: any PPO failure (missing module, bad weights,
            # runtime error) degrades to the heuristic baseline, but the
            # cause is logged rather than silently discarded.
            import logging

            logging.getLogger(__name__).warning(
                "PPO baseline unavailable for task %r; "
                "falling back to heuristic baseline",
                task_id,
                exc_info=True,
            )

        if baseline is None:
            baseline = self.get_heuristic_baseline(task_id)
        self._ppo_baselines[task_id] = baseline
        return baseline

    def get_heuristic_baseline(self, task_id: str) -> float:
        """Return the cached raw score of ``HeuristicPolicy`` on ``task_id``."""
        if task_id in self._heuristic_baselines:
            return self._heuristic_baselines[task_id]

        baseline = self._run_policy_episode(
            task_id, self.HEURISTIC_BASELINE_SEED, HeuristicPolicy()
        )
        self._heuristic_baselines[task_id] = baseline
        return baseline

    def _compute_raw_score(self, episode_log: EpisodeLog) -> tuple[float, dict[str, float]]:
        """Compute the un-normalized score and its per-component breakdown.

        Each component is derived from the mean of a per-observation metric:

        * throughput and SLO: ratio to the oracle reference, capped at 1.0;
        * memory: 1.0 up to ``MEMORY_BUDGET_GB``, then a linear penalty;
        * cost: 1.0 up to the oracle cost, then a linear penalty relative
          to the oracle cost.

        Returns:
            ``(score, breakdown)`` where ``score`` is the weighted sum clamped
            to ``[0, 1]`` and ``breakdown`` holds the components rounded to
            four decimals.  An episode with no observations scores 0.0.
        """
        observations = episode_log.observations
        if not observations:
            return 0.0, self._empty_breakdown()

        oracle = self.optimal_solver.oracle_reference(episode_log.task_id)
        count = len(observations)
        mean_throughput = sum(obs.throughput_tps for obs in observations) / count
        mean_slo = sum(obs.slo_compliance_rate for obs in observations) / count
        mean_memory = sum(obs.gpu_memory_used_gb for obs in observations) / count
        mean_cost = sum(obs.estimated_cost_per_1k for obs in observations) / count

        # Every oracle denominator is floored so a zero reference value cannot
        # raise ZeroDivisionError.
        eps = self._EPSILON
        throughput_component = min(1.0, mean_throughput / max(oracle["throughput_tps"], eps))
        slo_component = min(1.0, mean_slo / max(oracle["slo_compliance_rate"], eps))
        memory_overage = max(0.0, mean_memory - self.MEMORY_BUDGET_GB)
        memory_component = max(0.0, 1.0 - memory_overage / self.MEMORY_BUDGET_GB)
        cost_overage = max(0.0, mean_cost - oracle["cost_per_1k"])
        cost_component = max(0.0, 1.0 - cost_overage / max(oracle["cost_per_1k"], eps))

        score = (
            0.30 * throughput_component
            + 0.35 * slo_component
            + 0.20 * memory_component
            + 0.15 * cost_component
        )
        breakdown = {
            "throughput": round(throughput_component, 4),
            "slo": round(slo_component, 4),
            "memory": round(memory_component, 4),
            "cost": round(cost_component, 4),
        }
        return max(0.0, min(1.0, score)), breakdown

    def grade(self, episode_log: EpisodeLog, actions_taken: int | None = None) -> dict[str, object]:
        """Grade an episode and return a report dictionary.

        The raw score is mapped piecewise-linearly so that 0 stays 0, the
        anchor (the better of the heuristic and PPO baselines) maps to 0.5,
        and 1.0 stays 1.0.

        Args:
            episode_log: Log of the episode to grade.
            actions_taken: Action count to report; defaults to the number of
                actions recorded in ``episode_log``.

        Returns:
            A dict with ``task_id``, ``actions_taken``, ``score`` and
            ``breakdown``.  When the log has observations it also carries
            ``heuristic_baseline``, ``ppo_baseline`` and ``raw_score``.
        """
        if actions_taken is None:
            actions_taken = len(episode_log.actions)

        if not episode_log.observations:
            # No baselines are computed for an empty episode.
            return {
                "task_id": episode_log.task_id,
                "actions_taken": actions_taken,
                "score": 0.0,
                "breakdown": self._empty_breakdown(),
            }

        raw_score, breakdown = self._compute_raw_score(episode_log)
        heuristic_baseline = self.get_heuristic_baseline(episode_log.task_id)
        ppo_baseline = self.get_ppo_baseline(episode_log.task_id)
        anchor = max(heuristic_baseline, ppo_baseline, self._EPSILON)

        if raw_score <= anchor:
            final_score = 0.5 * (raw_score / anchor)
        else:
            headroom = max(1.0 - anchor, self._EPSILON)
            final_score = 0.5 + 0.5 * ((raw_score - anchor) / headroom)

        return {
            "task_id": episode_log.task_id,
            "actions_taken": actions_taken,
            "score": round(max(0.0, min(1.0, final_score)), 4),
            "breakdown": breakdown,
            "heuristic_baseline": round(heuristic_baseline, 4),
            "ppo_baseline": round(ppo_baseline, 4),
            "raw_score": round(raw_score, 4),
        }
|
|