contextforge-demo / tests /test_speculative_coordinator.py
Pablo
feat: APOHARA: Context Forge V5 — synthesis + rebrand complete
cf0a8ed
"""Tests for SpeculativeCoordinator — TASK-003.
Tests cover:
- Config dataclass initialization and defaults
- Role-based viability checking (is_speculative_viable)
- Draft buffering (submit_draft) in both sync and overlapped modes
- verify_and_commit acceptance sampling
- estimate_speedup mathematical correctness
- Edge case: empty draft tokens
"""
import asyncio
import math
import random
from unittest.mock import MagicMock
import pytest
from apohara_context_forge.decoding.speculative_coordinator import (
SpeculativeConfig,
SpeculativeCoordinator,
SpeculativeResult,
)
class TestSpeculativeConfig:
"""Tests for SpeculativeConfig dataclass."""
def test_default_values(self):
"""SpeculativeConfig has correct defaults."""
config = SpeculativeConfig()
assert config.draft_agent_roles == frozenset({"retriever", "reranker"})
assert config.target_agent_roles == frozenset({"responder", "critic"})
assert config.max_draft_tokens == 8
assert config.acceptance_threshold == 0.9
assert config.enable_overlapped is True
assert config.min_stability_rho == 0.8
def test_custom_values(self):
"""Custom values are stored correctly."""
config = SpeculativeConfig(
max_draft_tokens=16,
acceptance_threshold=0.95,
enable_overlapped=False,
min_stability_rho=0.6,
)
assert config.max_draft_tokens == 16
assert config.acceptance_threshold == 0.95
assert config.enable_overlapped is False
assert config.min_stability_rho == 0.6
class TestSpeculativeCoordinator:
"""Tests for SpeculativeCoordinator."""
def test_is_speculative_viable_draft_role_ok(self):
"""Draft agent with allowed role returns True."""
coordinator = SpeculativeCoordinator()
# "retriever-0" extracts role "retriever" which is in draft roles.
assert coordinator.is_speculative_viable("retriever-0", "responder-0") is True
def test_is_speculative_viable_target_role_ok(self):
"""Target agent with allowed role returns True."""
coordinator = SpeculativeCoordinator()
# "responder-1" extracts role "responder" which is in target roles.
assert coordinator.is_speculative_viable("retriever-0", "responder-0") is True
def test_is_speculative_viable_wrong_draft_role(self):
"""Draft agent with disallowed role returns False."""
coordinator = SpeculativeCoordinator()
# "responder" role not in draft roles.
result = coordinator.is_speculative_viable("responder-0", "responder-0")
assert result is False
def test_is_speculative_viable_wrong_target_role(self):
"""Target agent with disallowed role returns False."""
coordinator = SpeculativeCoordinator()
# "retriever" role not in target roles.
result = coordinator.is_speculative_viable("retriever-0", "retriever-0")
assert result is False
def test_is_speculative_viable_rho_check(self):
"""rho above threshold blocks speculative decoding."""
mock_qc = MagicMock()
mock_qc.current_rho = MagicMock(return_value=0.9)
config = SpeculativeConfig(min_stability_rho=0.8)
coordinator = SpeculativeCoordinator(config=config, queueing_controller=mock_qc)
# rho=0.9 >= min_stability_rho=0.8 → blocked.
result = coordinator.is_speculative_viable("retriever-0", "responder-0")
assert result is False
def test_is_speculative_viable_rho_below_threshold(self):
"""rho below threshold allows speculative decoding."""
mock_qc = MagicMock()
mock_qc.current_rho = MagicMock(return_value=0.5)
config = SpeculativeConfig(min_stability_rho=0.8)
coordinator = SpeculativeCoordinator(config=config, queueing_controller=mock_qc)
# rho=0.5 < min_stability_rho=0.8 → allowed.
result = coordinator.is_speculative_viable("retriever-0", "responder-0")
assert result is True
@pytest.mark.asyncio
async def test_submit_draft_sync_mode(self):
"""submit_draft buffers draft in sync (non-overlapped) mode."""
config = SpeculativeConfig(enable_overlapped=False)
coordinator = SpeculativeCoordinator(config=config)
draft_tokens = [101, 202, 303]
await coordinator.submit_draft(draft_tokens, "responder-0", step=1)
assert coordinator._current_draft == ("responder-0", draft_tokens)
@pytest.mark.asyncio
async def test_submit_draft_overlapped_mode(self):
"""submit_draft enqueues draft when overlapped mode is enabled."""
config = SpeculativeConfig(enable_overlapped=True)
coordinator = SpeculativeCoordinator(config=config)
draft_tokens = [101, 202, 303]
await coordinator.submit_draft(draft_tokens, "responder-0", step=1)
# Should be in the queue.
got = coordinator._draft_queue.get_nowait()
assert got == ("responder-0", draft_tokens)
@pytest.mark.asyncio
async def test_verify_and_commit_empty_draft(self):
"""Empty draft_tokens returns SpeculativeResult with all empty fields."""
coordinator = SpeculativeCoordinator()
result = await coordinator.verify_and_commit(
target_verification_logprobs=[], draft_tokens=[]
)
assert result.draft_tokens == []
assert result.accepted_tokens == []
assert result.rejected_at_position == -1
assert result.acceptance_rate == 1.0
assert result.decode_speedup_estimate == 1.0
assert result.overlapped_next_draft is None
@pytest.mark.asyncio
async def test_verify_and_commit_all_accepted(self):
"""
When random <= ratio for all tokens, all are accepted.
Uses fixed seed so result is deterministic.
"""
config = SpeculativeConfig(acceptance_threshold=0.9)
coordinator = SpeculativeCoordinator(config=config)
# High logprobs (close to 0) → high probs → ratio near 1.0.
# With acceptance_threshold=0.9, ratio = p_i / 0.9 ≈ 1.0.
# Seeded random=0.5 ≤ 1.0 → accept.
random.seed(0)
draft_tokens = [10, 20, 30]
logprobs = [0.0, 0.0, 0.0] # p ≈ 1.0 each
result = await coordinator.verify_and_commit(logprobs, draft_tokens)
# All should be accepted since ratio ≈ 1.0 and random(0.5) < 1.0.
assert result.accepted_tokens == draft_tokens
assert result.rejected_at_position == -1
assert result.acceptance_rate == 1.0
@pytest.mark.asyncio
async def test_verify_and_commit_rejection(self):
"""
When random > ratio the token is rejected at that position.
With very low logprobs the ratio is near 0, so rejection is likely.
"""
config = SpeculativeConfig(acceptance_threshold=0.9)
coordinator = SpeculativeCoordinator(config=config)
# Very negative logprobs → very low probs → ratio ≈ 0.
# random() will almost certainly be > ratio → rejection at position 0.
random.seed(42)
draft_tokens = [10, 20, 30]
logprobs = [-10.0, -10.0, -10.0] # p ≈ 4.5e-5
result = await coordinator.verify_and_commit(logprobs, draft_tokens)
# Should reject at position 0 since ratio is tiny.
assert result.rejected_at_position == 0
assert len(result.accepted_tokens) == 0
@pytest.mark.asyncio
async def test_verify_and_commit_partial_acceptance(self):
"""
Some tokens accepted, then rejection occurs.
Uses intermediate logprobs for mixed outcome.
"""
config = SpeculativeConfig(acceptance_threshold=0.9)
coordinator = SpeculativeCoordinator(config=config)
random.seed(12345)
draft_tokens = [10, 20, 30, 40, 50]
# Tuned logprobs so first 2 accept, 3rd rejects.
# logprob=-0.1 → p≈0.90, ratio=1.0 → accept if random ≤ 1.0
# logprob=-2.3 → p≈0.10, ratio≈0.11 → reject unless random < 0.11
logprobs = [-0.1, -0.1, -2.3, 0.0, 0.0]
result = await coordinator.verify_and_commit(logprobs, draft_tokens)
# First two should be accepted (random values ≤ 1.0).
assert len(result.accepted_tokens) >= 2
# If rejected, rejected_at_position reflects first failure.
assert result.rejected_at_position == -1 or result.rejected_at_position >= 2
@pytest.mark.asyncio
async def test_verify_and_commit_overlapped_next_draft(self):
"""
When enable_overlapped=True and queue has a prefetched draft,
overlapped_next_draft is populated in the result.
"""
config = SpeculativeConfig(enable_overlapped=True)
coordinator = SpeculativeCoordinator(config=config)
# Pre-load a draft into the queue.
prefetched_tokens = [999, 888, 777]
await coordinator._draft_queue.put(("responder-1", prefetched_tokens))
result = await coordinator.verify_and_commit(
target_verification_logprobs=[0.0, 0.0],
draft_tokens=[10, 20],
)
assert result.overlapped_next_draft == prefetched_tokens
@pytest.mark.asyncio
async def test_verify_and_commit_no_overlapped_next_draft(self):
"""
When queue is empty, overlapped_next_draft is None even if enabled.
"""
config = SpeculativeConfig(enable_overlapped=True)
coordinator = SpeculativeCoordinator(config=config)
# Queue is empty.
result = await coordinator.verify_and_commit(
target_verification_logprobs=[0.0],
draft_tokens=[10],
)
assert result.overlapped_next_draft is None
def test_estimate_speedup_max_acceptance(self):
"""100% acceptance → maximum speedup (k+1 tokens per step)."""
coordinator = SpeculativeCoordinator()
k = 8
speedup = coordinator.estimate_speedup(1.0, max_draft_tokens=k)
assert math.isclose(speedup, k + 1, rel_tol=1e-9)
def test_estimate_speedup_zero_acceptance(self):
"""0% acceptance → no speedup (only fallback token, speedup = 1.0)."""
coordinator = SpeculativeCoordinator()
speedup = coordinator.estimate_speedup(0.0, max_draft_tokens=8)
assert speedup == 1.0
def test_estimate_speedup_090_acceptance_k8(self):
"""
From the spec: acceptance_rate=0.9, k=8 → speedup ≈ 5.7x.
E[tokens] = (1 - r^(k+1)) / (1 - r)
= (1 - 0.9^9) / (1 - 0.9)
= (1 - 0.3874) / 0.1
≈ 6.126
"""
coordinator = SpeculativeCoordinator()
speedup = coordinator.estimate_speedup(0.9, max_draft_tokens=8)
expected = (1.0 - (0.9 ** 9)) / 0.1
assert math.isclose(speedup, expected, rel_tol=1e-9)
def test_estimate_speedup_out_of_range(self):
"""Acceptance rate outside [0,1] returns 1.0 (no speedup)."""
coordinator = SpeculativeCoordinator()
assert coordinator.estimate_speedup(-0.5, max_draft_tokens=8) == 1.0
assert coordinator.estimate_speedup(1.5, max_draft_tokens=8) == 1.0
def test_role_from_agent_id(self):
"""_role_from_agent_id extracts role from agent_id suffix."""
coordinator = SpeculativeCoordinator()
assert coordinator._role_from_agent_id("retriever-0") == "retriever"
assert coordinator._role_from_agent_id("responder-1") == "responder"
assert coordinator._role_from_agent_id("agent:reranker-2") == "reranker"
assert coordinator._role_from_agent_id("worker:critic-0") == "critic"