File size: 4,365 Bytes
4fbc241 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 | import sys
import os
import numpy as np
from typing import List
# Ensure projects root is in path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from server.llmserve_environment import LLMServeEnvironment
from llmserve_env.models import ServeAction, QuantizationTier
def test_quantization_jitter():
print("[INFO] Testing Quantization Jitter (Chiron 2024)...")
env = LLMServeEnvironment(seed=42)
# FP16 Jitter
env.reset(task_id="static_workload")
fp16_latencies = []
for _ in range(50): # Avoid 100-step Mega-Prompt spike
obs = env.step(ServeAction(quantization_tier=QuantizationTier.FP16.value, batch_cap=200))
fp16_latencies.append(obs.p50_ttft_ms)
fp16_cv = np.std(fp16_latencies) / np.mean(fp16_latencies)
print(f" FP16 CV: {fp16_cv:.4f}")
# INT4 Jitter
env.reset(task_id="static_workload")
int4_latencies = []
for _ in range(50):
obs = env.step(ServeAction(quantization_tier=QuantizationTier.INT4.value, batch_cap=200))
int4_latencies.append(obs.p50_ttft_ms)
int4_cv = np.std(int4_latencies) / np.mean(int4_latencies)
print(f" INT4 CV: {int4_cv:.4f}")
# Assert INT4 has notably higher jitter
assert int4_cv > fp16_cv, f"INT4 Jitter ({int4_cv:.4f}) must be > FP16 Jitter ({fp16_cv:.4f})"
print("[PASS] Quantization Jitter verified.")
def test_thermal_throttling():
print("[INFO] Testing Thermal Throttling Trigger...")
env = LLMServeEnvironment(seed=42)
env.reset(task_id="static_workload")
# Run 100 steps of low load
for i in range(100):
env.step(ServeAction(batch_cap=10))
obs_normal = env.step(ServeAction(batch_cap=10))
assert not obs_normal.metadata["is_throttled"], "Should not be throttled yet"
# Run 120 steps at low batch_cap to force queue growth (utilization)
# Trigger requires step_index > 100
for _ in range(120):
obs = env.step(ServeAction(batch_cap=512))
print(f" Step 120: Throttled={obs.metadata['is_throttled']}")
assert obs.metadata['is_throttled'], "Thermal throttling should be active"
print("[SUCCESS] Thermal Throttling Verified.")
def test_priority_preemption():
print("[INFO] Testing Priority Preemption...")
env = LLMServeEnvironment(seed=42)
# TASK_ID affects alpha, but here we check preemption
# We need a workload that fills the cache.
# We use a very small batch_cap to force queue growth
env.reset(task_id="adversarial_multitenant")
preemption_triggered = False
for i in range(40):
# Small batch_cap=2 forces queue to grow by ~178 per step (arrival is 180)
# queue_depth * 512 / (16000 * 0.1) > 0.95
# queue_depth * 512 / 1600 > 0.95 => queue_depth > 3
obs = env.step(ServeAction(priority_routing=True, kv_budget_fraction=0.1, batch_cap=2))
if obs.metadata["preemption_events"] > 0:
preemption_triggered = True
print(f" Step {i}: Preemption Triggered! Events: {obs.metadata['preemption_events']}")
break
assert preemption_triggered, "Priority routing should trigger preemption when cache is full"
print("[SUCCESS] Priority Preemption Verified.")
def test_speculative_acceptance():
print("[INFO] Testing Speculative Alpha (Chat vs API)...")
env = LLMServeEnvironment(seed=42)
# Chat Task
env.reset(task_id="static_workload")
obs_chat = env.step(ServeAction(speculation_depth=4))
# API Task
env.reset(task_id="adversarial_multitenant")
obs_api = env.step(ServeAction(speculation_depth=4))
print(f" Chat Alpha: {obs_chat.spec_acceptance_rate:.4f}")
print(f" API Alpha: {obs_api.spec_acceptance_rate:.4f}")
assert obs_chat.spec_acceptance_rate > obs_api.spec_acceptance_rate, "Chat should have higher acceptance than API"
print("[SUCCESS] Speculative Alpha Verified.")
if __name__ == "__main__":
try:
test_quantization_jitter()
test_thermal_throttling()
test_priority_preemption()
test_speculative_acceptance()
print("\n[ALL TESTS PASSED] Physical Binary Triggers are fully functional.")
except Exception as e:
print(f"\n[FAIL] Trigger Verification Failed: {e}")
sys.exit(1)
|