"""Tests for ``LLMServeEnvironment`` reset/step/export and ``ServeAction`` validation."""

from __future__ import annotations

import pytest

from llmserve_env.models import ServeAction, default_action
from server.llmserve_environment import LLMServeEnvironment


def _make_env(task_id: str = "static_workload", seed: int = 42) -> LLMServeEnvironment:
    """Build an environment with ``seed`` and reset it onto ``task_id``."""
    environment = LLMServeEnvironment(seed=seed)
    environment.reset(task_id=task_id, seed=seed)
    return environment


def test_reset_returns_observation() -> None:
    """After reset, the latest observation is the step-0, not-done one."""
    initial = _make_env().observations[-1]

    assert initial.task_id == "static_workload"
    assert initial.step_index == 0
    assert initial.done is False


def test_reset_respects_requested_task_id() -> None:
    """The task id passed to reset shows up in state, observation and metadata."""
    environment = _make_env(task_id="adversarial_multitenant")
    latest = environment.observations[-1]

    assert environment.state.task_id == "adversarial_multitenant"
    assert latest.task_id == "adversarial_multitenant"
    assert latest.metadata["task_name"] == "Adversarial Multi-Tenant Serving"


def test_serve_action_defaults_are_valid() -> None:
    """A default-constructed action satisfies the lower bounds checked here."""
    default = ServeAction()

    assert default.batch_cap >= 1
    assert default.kv_budget_fraction >= 0.1


def test_serve_action_normalizes_invalid_web_values() -> None:
    """Out-of-range or unrecognised field values are normalised, not rejected."""
    raw_fields = {
        "batch_cap": 0,
        "kv_budget_fraction": 30,
        "speculation_depth": 40,
        "quantization_tier": "8",
    }
    normalized = ServeAction(**raw_fields)

    assert normalized.batch_cap == 1
    assert normalized.kv_budget_fraction == 1.0
    assert normalized.speculation_depth == 8
    assert normalized.quantization_tier == "FP16"


def test_serve_action_schema_exposes_quantization_enum() -> None:
    """The JSON schema lists the quantization tiers as an enum, in order."""
    properties = ServeAction.model_json_schema()["properties"]
    tier_field = properties["quantization_tier"]

    assert tier_field["enum"] == ["FP16", "INT8", "INT4"]


def test_reset_creates_unique_episode_id() -> None:
    """Two consecutive resets on one environment yield different episode ids."""
    environment = LLMServeEnvironment(seed=1)

    episode_ids = []
    for reset_seed in (1, 2):
        environment.reset(task_id="static_workload", seed=reset_seed)
        episode_ids.append(environment.state.episode_id)

    first, second = episode_ids
    assert first != second


def test_step_returns_observation_with_reward() -> None:
    """A single step advances the step index and carries typed reward/done."""
    result = _make_env().step(default_action())

    assert result.step_index == 1
    assert isinstance(result.reward, float)
    assert isinstance(result.done, bool)


def test_step_before_reset_raises() -> None:
    """Stepping an environment that was never reset raises ``RuntimeError``."""
    fresh = LLMServeEnvironment(seed=2)

    with pytest.raises(RuntimeError, match="reset"):
        fresh.step(default_action())


def test_step_updates_state() -> None:
    """One step bumps the step count and advances simulated time."""
    environment = _make_env()
    environment.step(default_action())

    current = environment.state
    assert current.step_count == 1
    assert current.elapsed_simulated_time_s > 0


def test_done_after_max_steps() -> None:
    """Running to completion marks the state done; a further step stays done."""
    environment = _make_env("static_workload")

    latest = environment.observations[-1]
    while True:
        if latest.done:
            break
        latest = environment.step(default_action())

    assert environment.state.done is True

    # Stepping a finished episode should still answer, flagged done and
    # carrying a "message" entry in its metadata.
    after_end = environment.step(default_action())
    assert after_end.done is True
    assert "message" in after_end.metadata


def test_export_episode_log() -> None:
    """The exported log holds one action/reward per step plus the reset observation."""
    steps_taken = 3
    environment = _make_env()
    for _ in range(steps_taken):
        environment.step(default_action())

    episode_log = environment.export_episode_log()

    assert len(episode_log.actions) == steps_taken
    assert len(episode_log.rewards) == steps_taken
    # One more observation than steps: the checks above show an observation
    # already exists at step 0 before any step is taken.
    assert len(episode_log.observations) == steps_taken + 1