Spaces:
Sleeping
Sleeping
File size: 9,986 Bytes
d9c2197 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 | """TokenDance — Master-Mirror Storage for collective KV cache sharing.
Based on TokenDance (arXiv:2604.03143, Apr 2026): "Collective KV Cache
Sharing for Multi-Agent Inference."
Idea: instead of storing N independent KV caches for N agents, store one
"master" KV cache and (N-1) sparse diffs ("mirrors"). When agents share a
common prefix and diverge only on a small subset of blocks, the diff is
mostly zero — block-sparse storage compresses it 11–17x.
Storage layout:
master_cache[m_id] full KV blocks for master agent
mirrors[a_id] = SparseKVDiff( sparse delta vs master:
block_indices: indices of blocks that differ
diff_values: the per-block deltas at those indices
)
Reconstruction:
full_kv[a_id] = master_cache[m_id].copy()
full_kv[a_id][block_indices] += diff_values
Diff threshold (default 1e-4) controls sparsity: blocks with L2 norm of
delta below threshold are dropped (reconstruction within tolerance).
Collective reuse step (All-Gather pattern): given a new round's shared
context, push the update once to the master and re-derive all mirror
diffs. Cost is O(blocks) regardless of agent count.
Pure numpy. No GPU dependency. Graceful degradation principle.
"""
from __future__ import annotations
from dataclasses import dataclass, field
import numpy as np
@dataclass
class SparseKVDiff:
"""Sparse delta of an agent's KV blocks vs the master agent's blocks.
Only blocks whose L2 norm of the delta exceeds the diff threshold are
stored. Reconstruction adds these deltas back to the corresponding
master blocks; all other blocks are byte-identical to the master.
"""
block_indices: np.ndarray # shape (n_diff_blocks,) int
diff_values: np.ndarray # shape (n_diff_blocks, *block_shape) float
total_blocks: int # original number of blocks (for reconstruction)
threshold: float = 1e-4
@property
def n_diff_blocks(self) -> int:
return int(self.block_indices.shape[0])
@property
def sparsity(self) -> float:
if self.total_blocks == 0:
return 0.0
return 1.0 - self.n_diff_blocks / self.total_blocks
class TokenDanceStorage:
"""Master-Mirror diff storage for multi-agent KV cache.
Stores 1 full Master KV cache + (N-1) block-sparse diffs.
Achieves 11-17x compression vs storing N full KV caches when agents
share large prefixes (typical in 5-agent RAG/Critic pipelines).
Based on: TokenDance (arXiv:2604.03143, Apr 2026).
"""
def __init__(self, diff_threshold: float = 1e-4):
self.diff_threshold: float = diff_threshold
self.master_id: str | None = None
self.master_cache: dict[str, np.ndarray] = {}
self.mirrors: dict[str, SparseKVDiff] = {}
# ------------------------------------------------------------------ #
# Public API #
# ------------------------------------------------------------------ #
def register_master(self, agent_id: str, kv_blocks: np.ndarray) -> None:
"""Register the master agent. The first call sets the reference KV.
Calling this again with a different agent_id replaces the master
and clears mirror state — all mirrors must be re-registered.
"""
if kv_blocks.ndim < 2:
raise ValueError(
f"kv_blocks must be at least 2D (n_blocks, ...); got shape {kv_blocks.shape}"
)
if self.master_id is not None and self.master_id != agent_id:
self.mirrors.clear()
self.master_cache.clear()
self.master_id = agent_id
self.master_cache[agent_id] = kv_blocks.copy()
def register_mirror(self, agent_id: str, kv_blocks: np.ndarray) -> SparseKVDiff:
"""Compute and store a sparse diff vs the master.
Only blocks whose per-block L2 norm of the delta exceeds
self.diff_threshold are kept; the rest are treated as identical.
"""
if self.master_id is None:
raise RuntimeError("register_master() must be called before register_mirror()")
master = self.master_cache[self.master_id]
if kv_blocks.shape != master.shape:
raise ValueError(
f"kv_blocks shape {kv_blocks.shape} must match master shape {master.shape}"
)
delta = kv_blocks - master
# Per-block L2 norm collapses all non-block dims into a single scalar.
flat = delta.reshape(delta.shape[0], -1)
per_block_norm = np.linalg.norm(flat, axis=1)
diff_mask = per_block_norm > self.diff_threshold
diff_indices = np.flatnonzero(diff_mask)
diff = SparseKVDiff(
block_indices=diff_indices.astype(np.int64),
diff_values=delta[diff_indices].copy() if diff_indices.size else np.empty(
(0,) + master.shape[1:], dtype=delta.dtype
),
total_blocks=master.shape[0],
threshold=self.diff_threshold,
)
self.mirrors[agent_id] = diff
return diff
def reconstruct(self, agent_id: str) -> np.ndarray:
"""Reconstruct the full KV cache for an agent."""
if self.master_id is None:
raise RuntimeError("No master registered")
if agent_id == self.master_id:
return self.master_cache[self.master_id].copy()
if agent_id not in self.mirrors:
raise KeyError(f"Unknown agent_id: {agent_id}")
diff = self.mirrors[agent_id]
out = self.master_cache[self.master_id].copy()
if diff.n_diff_blocks > 0:
out[diff.block_indices] = out[diff.block_indices] + diff.diff_values
return out
def compression_ratio(self) -> float:
"""Returns (sum of full per-agent block counts) / (master + diffs)."""
if self.master_id is None or not self.master_cache:
return 1.0
master_blocks = self.master_cache[self.master_id].shape[0]
n_agents = 1 + len(self.mirrors)
full_blocks = n_agents * master_blocks
stored_blocks = master_blocks + sum(d.n_diff_blocks for d in self.mirrors.values())
if stored_blocks == 0:
return float(n_agents)
return full_blocks / stored_blocks
def collective_reuse_step(
self,
agent_ids: list[str],
shared_blocks: np.ndarray,
) -> dict[str, int]:
"""All-Gather pattern: apply a shared-context update across agents.
Given a batch of new shared blocks (e.g. a freshly retrieved
context), append them to the master once and re-derive each
mirror's sparsity against the extended master.
The cost is O(master_blocks + total_diff_blocks) — paid once
regardless of agent count. The return value is per-agent diff
counts after the update for telemetry.
"""
if self.master_id is None:
raise RuntimeError("No master registered")
if shared_blocks.ndim < 2:
raise ValueError("shared_blocks must be at least 2D")
master = self.master_cache[self.master_id]
extended_master = np.concatenate([master, shared_blocks], axis=0)
self.master_cache[self.master_id] = extended_master
# Mirrors need to be extended to match the new master length.
# We assume agents adopt the shared blocks exactly (i.e. shared
# blocks are zero-diff for the mirrors). New mirror blocks are
# therefore identical to the appended master tail.
diff_counts: dict[str, int] = {self.master_id: 0}
for aid in agent_ids:
if aid == self.master_id:
continue
existing = self.mirrors.get(aid)
if existing is None:
# New mirror: identical to extended master so far.
self.mirrors[aid] = SparseKVDiff(
block_indices=np.empty((0,), dtype=np.int64),
diff_values=np.empty(
(0,) + extended_master.shape[1:], dtype=extended_master.dtype
),
total_blocks=extended_master.shape[0],
threshold=self.diff_threshold,
)
else:
# Pre-existing diffs unchanged; total_blocks bumps to new length.
self.mirrors[aid] = SparseKVDiff(
block_indices=existing.block_indices,
diff_values=existing.diff_values,
total_blocks=extended_master.shape[0],
threshold=existing.threshold,
)
diff_counts[aid] = self.mirrors[aid].n_diff_blocks
return diff_counts
# ------------------------------------------------------------------ #
# Introspection #
# ------------------------------------------------------------------ #
def stats(self) -> dict[str, float | int]:
master_blocks = (
self.master_cache[self.master_id].shape[0]
if self.master_id is not None
else 0
)
diff_blocks_total = sum(d.n_diff_blocks for d in self.mirrors.values())
return {
"master_id": self.master_id or "",
"master_blocks": master_blocks,
"n_mirrors": len(self.mirrors),
"diff_blocks_total": diff_blocks_total,
"compression_ratio": self.compression_ratio(),
"diff_threshold": self.diff_threshold,
}
def __repr__(self) -> str: # pragma: no cover - cosmetic
s = self.stats()
return (
f"TokenDanceStorage(master={s['master_id']!r}, "
f"master_blocks={s['master_blocks']}, mirrors={s['n_mirrors']}, "
f"diff_blocks={s['diff_blocks_total']}, "
f"compression={s['compression_ratio']:.2f}x, "
f"threshold={s['diff_threshold']:.0e})"
)
|