"""vLLM-ATOM Plugin for ContextForge V4.0.

ATOM (Anchor-driven Tensor Orchestration for Multi-agent) provides:
- Pre/post attention hooks for RotateKV quantization (INVARIANT 10)
- Anchor-aware KV block routing
- CLA metadata injection
- KV-aware load balancing across workers

Usage:
    from apohara_context_forge.serving.atom_plugin import vLLMAtomPlugin

    # Register with vLLM via entry_point in pyproject.toml
    # Plugin auto-initializes on vLLM worker startup
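
    A minimal manual-instantiation sketch (the worker_id and block/token
    IDs below are hypothetical; in production vLLM supplies them on
    worker startup):

        plugin = vLLMAtomPlugin()
        plugin.initialize(worker_id="worker-0", vllm_config={})
        meta = plugin.pre_attention_hook(["blk-0"], [101, 102], layer_idx=0)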
"""
from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import Any, Optional

logger = logging.getLogger(__name__)


@dataclass
class ATOMConfig:
    """ATOM plugin configuration."""

    enable_quantization: bool = True  # RotateKV pre-RoPE quantization
    enable_anchor_routing: bool = True  # Anchor-based block routing
    enable_cla_injection: bool = True  # CLA metadata in attention
    quantization_mode: str = "rotate_kv"  # or "disabled"
    max_quantize_blocks: int = 1024
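
# A configuration sketch; the values below are illustrative, not recommended
# defaults:
#
#     config = ATOMConfig(enable_quantization=False, quantization_mode="disabled")
#     plugin = vLLMAtomPlugin(config)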


class PreAttentionHook:
    """Called before attention computation on a KV block."""

    def __init__(self, config: ATOMConfig):
        self._config = config
        # Maps block_id -> quantization metadata; nothing in this module
        # populates it yet.
        self._quantized_blocks: dict[str, Any] = {}

    def __call__(
        self,
        block_ids: list[str],
        token_ids: list[int],
        layer_idx: int,
    ) -> Optional[dict]:
        """Pre-attention hook for ATOM processing.

        Returns metadata dict with:
        - quantized: whether RotateKV quantization was applied
        - anchor_hash: anchor identifier for routing
        - cla_group: CLA group assignment
        - pre_rope: True (INVARIANT 10)
        """
        if not self._config.enable_quantization:
            return None

        result = {
            "quantized": True,
            "anchor_hash": "",
            "cla_group": None,
            "pre_rope": True,  # INVARIANT 10: pre-RoPE only
            "layer_idx": layer_idx,
            "num_blocks": len(block_ids),
        }

        logger.debug(
            "ATOM pre-attention: layer=%d blocks=%d quantized=%s pre_rope=%s",
            layer_idx,
            len(block_ids),
            result["quantized"],
            result["pre_rope"],
        )

        return result


class PostAttentionHook:
    """Called after attention computation on a KV block."""

    def __init__(self, config: ATOMConfig):
        self._config = config
        self._stats = {"hits": 0, "misses": 0}

    @property
    def stats(self) -> dict:
        """Anchor hit/miss counters."""
        return self._stats

    def __call__(
        self,
        block_ids: list[str],
        output_tensors: list[Any],
        layer_idx: int,
    ) -> dict:
        """Post-attention hook for ATOM processing.

        Records anchor hit/miss for routing decisions.
        """
        self._stats["hits"] += len(block_ids)

        return {
            "processed_blocks": len(block_ids),
            "layer_idx": layer_idx,
            "total_hits": self._stats["hits"],
        }


class vLLMAtomPlugin:
    """vLLM-ATOM plugin for ContextForge V4.0.

    Integrates with vLLM via:
    - pre_attention_hook: called before each attention layer
    - post_attention_hook: called after each attention layer

    The plugin handles:
    1. RotateKV quantization of pre-RoPE tensors (INVARIANT 10)
    2. Anchor-aware KV block routing
    3. CLA metadata injection
    4. KV-aware worker load balancing
    """

    def __init__(self, config: Optional[ATOMConfig] = None):
        self._config = config or ATOMConfig()
        self._pre_hook = PreAttentionHook(self._config)
        self._post_hook = PostAttentionHook(self._config)
        self._initialized = False
        self._worker_id: Optional[str] = None

    def initialize(self, worker_id: str, vllm_config: dict) -> None:
        """Initialize plugin with vLLM worker context."""
        self._worker_id = worker_id
        self._initialized = True
        logger.info(f"ATOM plugin initialized: worker={worker_id}")

    @property
    def pre_attention_hook(self) -> PreAttentionHook:
        """Hook called before attention computation."""
        return self._pre_hook

    @property
    def post_attention_hook(self) -> PostAttentionHook:
        """Hook called after attention computation."""
        return self._post_hook

    def is_initialized(self) -> bool:
        """Check if plugin is initialized."""
        return self._initialized

    def get_stats(self) -> dict:
        """Return ATOM plugin statistics."""
        return {
            "initialized": self._initialized,
            "worker_id": self._worker_id,
            "config": {
                "enable_quantization": self._config.enable_quantization,
                "enable_anchor_routing": self._config.enable_anchor_routing,
                "enable_cla_injection": self._config.enable_cla_injection,
                "quantization_mode": self._config.quantization_mode,
            },
            "post_stats": self._post_hook._stats,
        }
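

# The smoke test below is an illustrative sketch, not part of the plugin API:
# it exercises the hooks with hypothetical block and token IDs to show the
# expected call pattern outside a live vLLM worker.
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)

    plugin = vLLMAtomPlugin()
    plugin.initialize(worker_id="worker-0", vllm_config={})

    # Hypothetical KV block and token IDs for a single layer.
    pre_meta = plugin.pre_attention_hook(
        block_ids=["blk-0", "blk-1"],
        token_ids=[101, 102, 103],
        layer_idx=0,
    )
    post_meta = plugin.post_attention_hook(
        block_ids=["blk-0", "blk-1"],
        output_tensors=[],
        layer_idx=0,
    )

    print(pre_meta)
    print(post_meta)
    print(plugin.get_stats())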