from __future__ import annotations
import pytest
from llmserve_env.models import ServeAction, default_action
from server.llmserve_environment import LLMServeEnvironment
def _make_env(task_id: str = "static_workload", seed: int = 42) -> LLMServeEnvironment:
    """Construct an environment and reset it to *task_id* with *seed*."""
    environment = LLMServeEnvironment(seed=seed)
    environment.reset(task_id=task_id, seed=seed)
    return environment
def test_reset_returns_observation() -> None:
    """A fresh reset produces a not-done initial observation at step 0."""
    env = _make_env()
    initial = env.observations[-1]
    assert initial.task_id == "static_workload"
    assert initial.step_index == 0
    assert initial.done is False
def test_reset_respects_requested_task_id() -> None:
    """Resetting with a non-default task id propagates to state and observation."""
    env = _make_env(task_id="adversarial_multitenant")
    latest = env.observations[-1]
    assert env.state.task_id == "adversarial_multitenant"
    assert latest.task_id == "adversarial_multitenant"
    assert latest.metadata["task_name"] == "Adversarial Multi-Tenant Serving"
def test_serve_action_defaults_are_valid() -> None:
    """Default-constructed actions satisfy the documented lower bounds."""
    default = ServeAction()
    assert default.batch_cap >= 1
    assert default.kv_budget_fraction >= 0.1
def test_serve_action_normalizes_invalid_web_values() -> None:
    """Out-of-range / malformed field values are clamped to valid ones."""
    clamped = ServeAction(
        batch_cap=0,
        kv_budget_fraction=30,
        speculation_depth=40,
        quantization_tier="8",
    )
    # Each field should be coerced onto its legal range rather than rejected.
    assert clamped.batch_cap == 1
    assert clamped.kv_budget_fraction == 1.0
    assert clamped.speculation_depth == 8
    assert clamped.quantization_tier == "FP16"
def test_serve_action_schema_exposes_quantization_enum() -> None:
    """The JSON schema advertises the closed set of quantization tiers."""
    tier_schema = ServeAction.model_json_schema()["properties"]["quantization_tier"]
    assert tier_schema["enum"] == ["FP16", "INT8", "INT4"]
def test_reset_creates_unique_episode_id() -> None:
    """Each reset mints a distinct episode identifier."""
    env = LLMServeEnvironment(seed=1)
    episode_ids = []
    for reset_seed in (1, 2):
        env.reset(task_id="static_workload", seed=reset_seed)
        episode_ids.append(env.state.episode_id)
    assert episode_ids[0] != episode_ids[1]
def test_step_returns_observation_with_reward() -> None:
    """Stepping yields the next-index observation with typed reward/done."""
    env = _make_env()
    result = env.step(default_action())
    assert result.step_index == 1
    assert isinstance(result.reward, float)
    assert isinstance(result.done, bool)
def test_step_before_reset_raises() -> None:
    """Calling step() on an un-reset environment is a RuntimeError."""
    fresh = LLMServeEnvironment(seed=2)
    with pytest.raises(RuntimeError, match="reset"):
        fresh.step(default_action())
def test_step_updates_state() -> None:
    """A single step advances the step counter and the simulated clock."""
    env = _make_env()
    env.step(default_action())
    state = env.state
    assert state.step_count == 1
    assert state.elapsed_simulated_time_s > 0
def test_done_after_max_steps() -> None:
    """Episodes terminate, and stepping past termination is a flagged no-op.

    The original version looped ``while not obs.done`` with no bound, so a
    regression in the termination logic would hang the whole test run; the
    loop is now capped and fails loudly instead.
    """
    env = _make_env("static_workload")
    obs = env.observations[-1]
    for _ in range(10_000):  # hard cap: fail fast rather than hang forever
        if obs.done:
            break
        obs = env.step(default_action())
    else:
        pytest.fail("episode never reported done within 10,000 steps")
    assert env.state.done is True
    # Stepping a finished episode must stay done and explain itself.
    repeated = env.step(default_action())
    assert repeated.done is True
    assert "message" in repeated.metadata
def test_export_episode_log() -> None:
    """The exported log records every action/reward plus the reset observation."""
    steps = 3
    env = _make_env()
    for _ in range(steps):
        env.step(default_action())
    episode_log = env.export_episode_log()
    assert len(episode_log.actions) == steps
    assert len(episode_log.rewards) == steps
    # One extra observation: the initial one produced by reset().
    assert len(episode_log.observations) == steps + 1
|