File size: 9,986 Bytes
d9c2197
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
"""TokenDance — Master-Mirror Storage for collective KV cache sharing.

Based on TokenDance (arXiv:2604.03143, Apr 2026): "Collective KV Cache
Sharing for Multi-Agent Inference."

Idea: instead of storing N independent KV caches for N agents, store one
"master" KV cache and (N-1) sparse diffs ("mirrors"). When agents share a
common prefix and diverge only on a small subset of blocks, the diff is
mostly zero — block-sparse storage compresses it 11–17x.

Storage layout:
    master_cache[m_id]                     full KV blocks for master agent
    mirrors[a_id] = SparseKVDiff(          sparse delta vs master:
        block_indices: indices of blocks that differ
        diff_values:   the per-block deltas at those indices
    )

Reconstruction:
    full_kv[a_id] = master_cache[m_id].copy()
    full_kv[a_id][block_indices] += diff_values

Diff threshold (default 1e-4) controls sparsity: blocks with L2 norm of
delta below threshold are dropped (reconstruction within tolerance).

Collective reuse step (All-Gather pattern): given a new round's shared
context, push the update once to the master and re-derive all mirror
diffs. Cost is O(blocks) regardless of agent count.

Pure numpy. No GPU dependency. Graceful degradation principle.
"""
from __future__ import annotations

from dataclasses import dataclass, field

import numpy as np


@dataclass
class SparseKVDiff:
    """Block-sparse delta between one agent's KV blocks and the master's.

    A block is recorded only when the L2 norm of its delta exceeds the
    diff threshold; every unrecorded block is byte-identical to the
    master. Reconstruction adds ``diff_values`` back onto the master
    blocks selected by ``block_indices``.
    """

    block_indices: np.ndarray  # (n_diff_blocks,) int — indices of differing blocks
    diff_values: np.ndarray    # (n_diff_blocks, *block_shape) float — per-block deltas
    total_blocks: int          # block count of the full cache being mirrored
    threshold: float = 1e-4    # L2 cutoff used when this diff was built

    @property
    def n_diff_blocks(self) -> int:
        """Number of blocks actually stored in this diff."""
        return int(len(self.block_indices))

    @property
    def sparsity(self) -> float:
        """Fraction of blocks identical to the master (0.0 when empty)."""
        total = self.total_blocks
        if not total:
            return 0.0
        return 1.0 - (self.n_diff_blocks / total)


class TokenDanceStorage:
    """Master-Mirror diff storage for multi-agent KV cache.

    Stores 1 full Master KV cache + (N-1) block-sparse diffs.
    Achieves 11-17x compression vs storing N full KV caches when agents
    share large prefixes (typical in 5-agent RAG/Critic pipelines).

    Based on: TokenDance (arXiv:2604.03143, Apr 2026).
    """

    def __init__(self, diff_threshold: float = 1e-4):
        # L2-norm cutoff below which a per-block delta is treated as zero.
        self.diff_threshold: float = diff_threshold
        # Agent id of the current master; None until register_master().
        self.master_id: str | None = None
        # Full KV blocks keyed by master agent id (one entry at a time).
        self.master_cache: dict[str, np.ndarray] = {}
        # Per-mirror sparse deltas vs the master.
        self.mirrors: dict[str, SparseKVDiff] = {}

    # ------------------------------------------------------------------ #
    # Public API                                                          #
    # ------------------------------------------------------------------ #

    def register_master(self, agent_id: str, kv_blocks: np.ndarray) -> None:
        """Register the master agent. The first call sets the reference KV.

        Replacing the master — with a different agent_id, or by
        re-registering the same agent with different blocks — clears all
        mirror state: stored diffs were computed against the old master
        and would otherwise silently reconstruct wrong values.

        Raises:
            ValueError: if kv_blocks is not at least 2D (n_blocks, ...).
        """
        if kv_blocks.ndim < 2:
            raise ValueError(
                f"kv_blocks must be at least 2D (n_blocks, ...); got shape {kv_blocks.shape}"
            )
        if self.master_id is not None:
            # Fix: previously, same-id re-registration with new content kept
            # stale mirror diffs. Any change of master content invalidates them.
            stale = self.master_id != agent_id or not np.array_equal(
                self.master_cache[self.master_id], kv_blocks
            )
            if stale:
                self.mirrors.clear()
                self.master_cache.clear()
        self.master_id = agent_id
        self.master_cache[agent_id] = kv_blocks.copy()

    def register_mirror(self, agent_id: str, kv_blocks: np.ndarray) -> SparseKVDiff:
        """Compute and store a block-sparse diff of kv_blocks vs the master.

        Only blocks whose per-block L2 norm of the delta exceeds
        self.diff_threshold are kept; the rest are treated as identical.

        Raises:
            RuntimeError: if no master has been registered.
            ValueError: if kv_blocks' shape differs from the master's.
        """
        if self.master_id is None:
            raise RuntimeError("register_master() must be called before register_mirror()")
        master = self.master_cache[self.master_id]
        if kv_blocks.shape != master.shape:
            raise ValueError(
                f"kv_blocks shape {kv_blocks.shape} must match master shape {master.shape}"
            )

        delta = kv_blocks - master
        # Per-block L2 norm collapses all non-block dims into a single scalar.
        per_block_norm = np.linalg.norm(delta.reshape(delta.shape[0], -1), axis=1)
        diff_indices = np.flatnonzero(per_block_norm > self.diff_threshold)

        if diff_indices.size:
            # Fancy indexing already yields a fresh array; no extra copy needed.
            diff_values = delta[diff_indices]
        else:
            diff_values = np.empty((0,) + master.shape[1:], dtype=delta.dtype)

        diff = SparseKVDiff(
            block_indices=diff_indices.astype(np.int64),
            diff_values=diff_values,
            total_blocks=master.shape[0],
            threshold=self.diff_threshold,
        )
        self.mirrors[agent_id] = diff
        return diff

    def reconstruct(self, agent_id: str) -> np.ndarray:
        """Reconstruct the full KV cache for an agent.

        The master reconstructs to a copy of its own cache; a mirror
        reconstructs by adding its stored deltas onto a copy of the
        master at the recorded block indices.

        Raises:
            RuntimeError: if no master has been registered.
            KeyError: if agent_id is neither the master nor a mirror.
        """
        if self.master_id is None:
            raise RuntimeError("No master registered")
        if agent_id == self.master_id:
            return self.master_cache[self.master_id].copy()
        if agent_id not in self.mirrors:
            raise KeyError(f"Unknown agent_id: {agent_id}")

        diff = self.mirrors[agent_id]
        out = self.master_cache[self.master_id].copy()
        if diff.n_diff_blocks > 0:
            out[diff.block_indices] = out[diff.block_indices] + diff.diff_values
        return out

    def compression_ratio(self) -> float:
        """Returns (sum of full per-agent block counts) / (master + diffs).

        1.0 means no savings; higher is better. Returns 1.0 when no
        master is registered, and n_agents when zero blocks are stored
        (degenerate empty-master case).
        """
        if self.master_id is None or not self.master_cache:
            return 1.0
        master_blocks = self.master_cache[self.master_id].shape[0]
        n_agents = 1 + len(self.mirrors)
        full_blocks = n_agents * master_blocks
        stored_blocks = master_blocks + sum(d.n_diff_blocks for d in self.mirrors.values())
        if stored_blocks == 0:
            return float(n_agents)
        return full_blocks / stored_blocks

    def collective_reuse_step(
        self,
        agent_ids: list[str],
        shared_blocks: np.ndarray,
    ) -> dict[str, int]:
        """All-Gather pattern: apply a shared-context update across agents.

        Given a batch of new shared blocks (e.g. a freshly retrieved
        context), append them to the master once and re-derive each
        mirror's sparsity against the extended master.

        The cost is O(master_blocks + total_diff_blocks) — paid once
        regardless of agent count. The return value is per-agent diff
        counts after the update for telemetry.

        Raises:
            RuntimeError: if no master has been registered.
            ValueError: if shared_blocks is not at least 2D or its
                per-block shape differs from the master's.
        """
        if self.master_id is None:
            raise RuntimeError("No master registered")
        if shared_blocks.ndim < 2:
            raise ValueError("shared_blocks must be at least 2D")

        master = self.master_cache[self.master_id]
        # Explicit shape check: np.concatenate would raise anyway, but with
        # an opaque message; fail fast with the actual shapes involved.
        if shared_blocks.shape[1:] != master.shape[1:]:
            raise ValueError(
                f"shared_blocks per-block shape {shared_blocks.shape[1:]} must "
                f"match master per-block shape {master.shape[1:]}"
            )
        extended_master = np.concatenate([master, shared_blocks], axis=0)
        self.master_cache[self.master_id] = extended_master

        # Mirrors need to be extended to match the new master length.
        # We assume agents adopt the shared blocks exactly (i.e. shared
        # blocks are zero-diff for the mirrors). New mirror blocks are
        # therefore identical to the appended master tail.
        diff_counts: dict[str, int] = {self.master_id: 0}
        for aid in agent_ids:
            if aid == self.master_id:
                continue
            existing = self.mirrors.get(aid)
            if existing is None:
                # New mirror: identical to extended master so far.
                self.mirrors[aid] = SparseKVDiff(
                    block_indices=np.empty((0,), dtype=np.int64),
                    diff_values=np.empty(
                        (0,) + extended_master.shape[1:], dtype=extended_master.dtype
                    ),
                    total_blocks=extended_master.shape[0],
                    threshold=self.diff_threshold,
                )
            else:
                # Pre-existing diffs unchanged; total_blocks bumps to new length.
                self.mirrors[aid] = SparseKVDiff(
                    block_indices=existing.block_indices,
                    diff_values=existing.diff_values,
                    total_blocks=extended_master.shape[0],
                    threshold=existing.threshold,
                )
            diff_counts[aid] = self.mirrors[aid].n_diff_blocks
        return diff_counts

    # ------------------------------------------------------------------ #
    # Introspection                                                       #
    # ------------------------------------------------------------------ #

    def stats(self) -> dict[str, float | int | str]:
        """Telemetry snapshot: master id/size, mirror count, compression.

        Note: return annotation includes str — master_id is a string
        (the original ``dict[str, float | int]`` annotation was wrong).
        """
        master_blocks = (
            self.master_cache[self.master_id].shape[0]
            if self.master_id is not None
            else 0
        )
        diff_blocks_total = sum(d.n_diff_blocks for d in self.mirrors.values())
        return {
            "master_id": self.master_id or "",
            "master_blocks": master_blocks,
            "n_mirrors": len(self.mirrors),
            "diff_blocks_total": diff_blocks_total,
            "compression_ratio": self.compression_ratio(),
            "diff_threshold": self.diff_threshold,
        }

    def __repr__(self) -> str:  # pragma: no cover - cosmetic
        s = self.stats()
        return (
            f"TokenDanceStorage(master={s['master_id']!r}, "
            f"master_blocks={s['master_blocks']}, mirrors={s['n_mirrors']}, "
            f"diff_blocks={s['diff_blocks_total']}, "
            f"compression={s['compression_ratio']:.2f}x, "
            f"threshold={s['diff_threshold']:.0e})"
        )