File size: 3,866 Bytes
4fbc241
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""Heuristic baseline policy for LLM serving configuration.

Simple threshold rules loosely inspired by three papers (nothing is learned here):
  - Orca (OSDI 2022): dynamic iteration-level batching / queue management
  - vLLM / PagedAttention (SOSP 2023): KV cache memory management
  - Decima (SIGCOMM 2019): workload-adaptive scheduling via RL
"""
from __future__ import annotations

from llmserve_env.models import QuantizationTier, ServeAction, ServeObservation


class HeuristicPolicy:
    """Reactive heuristic agent that adjusts serving config based on observations.

    The policy is stateful: every call to :meth:`act` nudges the current knob
    values (batch cap, KV budget fraction, speculation depth, quantization
    tier, prefill/decode split, priority routing) and those values carry over
    to the next call. Call :meth:`reset` at the start of each episode.
    """

    def __init__(self) -> None:
        # Delegate to reset() so a freshly constructed policy and a reset
        # policy are guaranteed to start from the same configuration.
        self.reset()

    def reset(self) -> None:
        """Restore every knob to its starting value for a new episode."""
        self.batch_cap: int = 32
        self.kv_budget_fraction: float = 0.70
        self.speculation_depth: int = 0
        self.quantization_tier: str = QuantizationTier.FP16.value
        self.prefill_decode_split: bool = False
        self.priority_routing: bool = False

    def act(self, observation: ServeObservation, task_id: str) -> ServeAction:
        """Update the internal config from *observation* and return it as an action.

        Args:
            observation: Latest environment observation. The fields read are
                ``slo_compliance_rate``, ``queue_depth``, ``eviction_events``,
                ``kv_cache_occupancy``, ``request_arrival_rate`` and
                ``mean_prompt_length``.
            task_id: Identifier of the current task. Only
                ``"adversarial_multitenant"`` changes behaviour (it can enable
                priority routing).

        Returns:
            A ``ServeAction`` carrying the updated knob values. The emitted
            ``kv_budget_fraction`` is rounded to two decimals; the unrounded
            value is kept internally for the next step.
        """

        # --- Orca rules: dynamic batching / queue management ---
        # SLO violations take precedence and shrink the cap sharply; otherwise
        # grow when the queue backs up relative to the cap and shrink when it
        # is mostly idle.
        if observation.slo_compliance_rate < 0.85:
            self.batch_cap = max(1, self.batch_cap - 32)
        elif observation.queue_depth > 0.7 * self.batch_cap:
            self.batch_cap = min(512, self.batch_cap + 16)
        elif observation.queue_depth < 0.2 * self.batch_cap and self.batch_cap > 16:
            # NOTE(review): the `> 16` guard does not hold the cap at 16 —
            # a cap of 17 shrinks to 1. Confirm whether a floor of 16 was meant.
            self.batch_cap = max(1, self.batch_cap - 16)

        # --- vLLM / PagedAttention rules: memory management ---
        if observation.eviction_events > 0:
            # NOTE(review): snaps to a fixed 0.60, which *raises* the budget if
            # the occupancy rule below had already lowered it under 0.60.
            self.kv_budget_fraction = 0.60
        elif observation.kv_cache_occupancy > 0.85:
            self.kv_budget_fraction = max(0.10, self.kv_budget_fraction - 0.10)
        elif observation.kv_cache_occupancy < 0.50 and self.kv_budget_fraction < 1.0:
            self.kv_budget_fraction = min(1.0, self.kv_budget_fraction + 0.10)

        # --- Decima rules: workload-adaptive optimisation ---
        # Arrival rates in [8, 25] leave the tier unchanged (hysteresis band).
        if observation.request_arrival_rate > 25:
            self.quantization_tier = QuantizationTier.INT8.value
        elif observation.request_arrival_rate < 8:
            self.quantization_tier = QuantizationTier.FP16.value

        # Prompt lengths in [200, 800] leave speculation depth unchanged.
        if observation.mean_prompt_length > 800:
            self.speculation_depth = 0
        elif observation.mean_prompt_length < 200:
            self.speculation_depth = 4

        # Use priority routing on adversarial task with long prompts
        if task_id == "adversarial_multitenant" and observation.mean_prompt_length > 2000:
            self.priority_routing = True
        else:
            self.priority_routing = False

        # Enable chunked prefill when under high queue pressure; compared
        # against the batch cap *as updated above* in this same step.
        self.prefill_decode_split = observation.queue_depth > 0.5 * self.batch_cap

        return ServeAction(
            batch_cap=self.batch_cap,
            kv_budget_fraction=round(self.kv_budget_fraction, 2),
            speculation_depth=self.speculation_depth,
            quantization_tier=self.quantization_tier,
            prefill_decode_split=self.prefill_decode_split,
            priority_routing=self.priority_routing,
        )


# ---------------------------------------------------------------------------
# Legacy function interface for backward-compatibility
# ---------------------------------------------------------------------------
# Shared instance behind the legacy function below. ``act`` mutates its knob
# values, so state persists from one ``baseline_policy`` call to the next
# unless ``_default_policy.reset()`` is called.
_default_policy = HeuristicPolicy()


def baseline_policy(observation: ServeObservation, task_id: str) -> ServeAction:
    """Return the shared heuristic policy's action for *observation*.

    Kept for callers of the old function-based API. It delegates to the
    module-level ``HeuristicPolicy`` instance, so its internal state carries
    over between calls.
    """
    shared = _default_policy
    action = shared.act(observation, task_id)
    return action