# server/grader.py
# (origin: vegarl HF Space, commit 4fbc241 — "Deploy Space without oversized raw dataset", ronitraj)
from __future__ import annotations
from llmserve_env.models import EpisodeLog
from server.baseline_agent import HeuristicPolicy
from server.llmserve_environment import LLMServeEnvironment
from server.optimal_solver import OptimalSolver
class GraderEngine:
_shared_ppo_baselines: dict[str, float] = {}
_shared_heuristic_baselines: dict[str, float] = {}
def __init__(self) -> None:
self.optimal_solver = OptimalSolver()
self._ppo_baselines = self._shared_ppo_baselines
self._heuristic_baselines = self._shared_heuristic_baselines
def _run_policy_episode(self, task_id: str, seed: int, policy) -> float:
env = LLMServeEnvironment(seed=seed, mode="sim")
if hasattr(policy, "reset"):
policy.reset()
observation = env.reset(seed=seed, task_id=task_id)
task_cfg = env.task_config or {}
max_steps = int(task_cfg.get("max_steps", 60))
for _ in range(max_steps):
action = policy.act(observation, task_id)
observation = env.step(action)
if bool(getattr(observation, "done", False)):
break
raw_score, _ = self._compute_raw_score(env.export_episode_log())
return raw_score
def get_ppo_baseline(self, task_id: str) -> float:
if task_id in self._ppo_baselines:
return self._ppo_baselines[task_id]
try:
from agents.ppo_agent import PPOAgent, find_weights
weights_path = find_weights(task_id)
if not weights_path:
heuristic_baseline = self.get_heuristic_baseline(task_id)
self._ppo_baselines[task_id] = heuristic_baseline
return heuristic_baseline
agent = PPOAgent(weights_path)
baseline = self._run_policy_episode(task_id, 42, agent)
self._ppo_baselines[task_id] = baseline
return baseline
except Exception:
heuristic_baseline = self.get_heuristic_baseline(task_id)
self._ppo_baselines[task_id] = heuristic_baseline
return heuristic_baseline
def get_heuristic_baseline(self, task_id: str) -> float:
if task_id in self._heuristic_baselines:
return self._heuristic_baselines[task_id]
policy = HeuristicPolicy()
baseline = self._run_policy_episode(task_id, 142, policy)
self._heuristic_baselines[task_id] = baseline
return baseline
def _compute_raw_score(self, episode_log: EpisodeLog) -> tuple[float, dict[str, float]]:
observations = episode_log.observations
if not observations:
return 0.0, {"throughput": 0.0, "slo": 0.0, "memory": 0.0, "cost": 0.0}
oracle = self.optimal_solver.oracle_reference(episode_log.task_id)
mean_throughput = sum(obs.throughput_tps for obs in observations) / len(observations)
mean_slo = sum(obs.slo_compliance_rate for obs in observations) / len(observations)
mean_memory = sum(obs.gpu_memory_used_gb for obs in observations) / len(observations)
mean_cost = sum(obs.estimated_cost_per_1k for obs in observations) / len(observations)
throughput_component = min(1.0, mean_throughput / oracle["throughput_tps"])
slo_component = min(1.0, mean_slo / oracle["slo_compliance_rate"])
memory_component = max(0.0, 1.0 - max(0.0, mean_memory - 38.0) / 38.0)
cost_component = max(0.0, 1.0 - max(0.0, mean_cost - oracle["cost_per_1k"]) / max(oracle["cost_per_1k"], 1e-6))
score = (
0.30 * throughput_component
+ 0.35 * slo_component
+ 0.20 * memory_component
+ 0.15 * cost_component
)
return max(0.0, min(1.0, score)), {
"throughput": round(throughput_component, 4),
"slo": round(slo_component, 4),
"memory": round(memory_component, 4),
"cost": round(cost_component, 4),
}
def grade(self, episode_log: EpisodeLog, actions_taken: int | None = None) -> dict[str, object]:
resolved_actions_taken = actions_taken if actions_taken is not None else len(episode_log.actions)
if not episode_log.observations:
return {
"task_id": episode_log.task_id,
"actions_taken": resolved_actions_taken,
"score": 0.0,
"breakdown": {"throughput": 0.0, "slo": 0.0, "memory": 0.0, "cost": 0.0},
}
raw_score, breakdown = self._compute_raw_score(episode_log)
heuristic_baseline = self.get_heuristic_baseline(episode_log.task_id)
ppo_baseline = self.get_ppo_baseline(episode_log.task_id)
anchor = max(heuristic_baseline, ppo_baseline, 1e-6)
if raw_score <= anchor:
final_score = 0.5 * (raw_score / anchor)
else:
final_score = 0.5 + 0.5 * ((raw_score - anchor) / max(1.0 - anchor, 1e-6))
return {
"task_id": episode_log.task_id,
"actions_taken": resolved_actions_taken,
"score": round(max(0.0, min(1.0, final_score)), 4),
"breakdown": breakdown,
"heuristic_baseline": round(heuristic_baseline, 4),
"ppo_baseline": round(ppo_baseline, 4),
"raw_score": round(raw_score, 4),
}